Source code for nodefinder.search.result._search_result_container

# -*- coding: utf-8 -*-

# © 2017-2019, ETH Zurich, Institut für Theoretische Physik
# Author: Dominik Gresch <greschd@gmx.ch>
"""
Defines the top-level result class for the search step.
"""

import numpy as np
from fsc.export import export
from fsc.hdf5_io import SimpleHDF5Mapping, subscribe_hdf5

from ._cell_list import CellList


@export
@subscribe_hdf5(
    'nodefinder.search_result_container',
    extra_tags=['nodefinder.result_container']
)
class SearchResultContainer(SimpleHDF5Mapping):
    """
    Container for the results of a search run.

    Attributes
    ----------
    coordinate_system : CoordinateSystem
        Coordinate system used.
    nodes : CellList
        Minimization results which fulfill the gap threshold criterion,
        stored under their fractional position.
    rejected_results : list(MinimizationResult)
        Minimization results which were unsuccessful or whose value exceeds
        the gap threshold.
    refined_results : CellList
        Positions which were used as starting point of a refinement, stored
        under their fractional position.
    gap_threshold : float
        Threshold for results to be considered a node.
    dist_cutoff : float
        Cutoff distance for searching neighbouring nodes.
    needs_saving : bool
        Set to ``True`` when the container is constructed and whenever a
        result or refined position is added. Presumably reset by the code
        which writes the container to disk -- that code is not part of this
        class.
    """
    HDF5_ATTRIBUTES = [
        'coordinate_system', 'minimization_results', 'dist_cutoff',
        'gap_threshold', 'refined_results'
    ]
    HDF5_OPTIONAL = ['refined_results']

    def __init__(
        self,
        *,
        coordinate_system,
        minimization_results=(),
        gap_threshold,
        dist_cutoff,
        refined_results=()
    ):
        """
        Arguments
        ---------
        coordinate_system : CoordinateSystem
            Coordinate system used.
        minimization_results : iterable(MinimizationResult)
            Initial results; each one is passed through :meth:`add_result`.
        gap_threshold : float
            Threshold for results to be considered a node.
        dist_cutoff : float
            Cutoff distance for searching neighbouring nodes.
        refined_results : iterable
            Initial refined positions; each one is passed through
            :meth:`set_refined`.
        """
        self.coordinate_system = coordinate_system
        self.gap_threshold = gap_threshold
        self.dist_cutoff = dist_cutoff

        # The number of cells is ``size / dist_cutoff`` (truncated to
        # integers), clamped element-wise to the range [1, 100]. A zero
        # cutoff would divide by zero, so the upper limit is used directly.
        if dist_cutoff == 0:
            num_cells = np.full_like(self.coordinate_system.size, 100)
        else:
            cells_from_cutoff = np.array(
                self.coordinate_system.size / self.dist_cutoff, dtype=int
            )
            num_cells = np.minimum(  # pylint: disable=assignment-from-no-return,useless-suppression
                100, np.maximum(1, cells_from_cutoff)
            )

        self.nodes = CellList(
            num_cells=num_cells, periodic=self.coordinate_system.periodic
        )
        self.rejected_results = []
        for res in minimization_results:
            self.add_result(res)

        self.refined_results = CellList(
            num_cells=num_cells, periodic=self.coordinate_system.periodic
        )
        for res in refined_results:
            self.set_refined(res)

        self.needs_saving = True

    def __repr__(self):
        """
        Short representation showing the number of minimization results
        instead of the results themselves.
        """
        return 'SearchResultContainer(coordinate_system={0.coordinate_system}, minimization_results=<{1} values>, gap_threshold={0.gap_threshold!r}, dist_cutoff={0.dist_cutoff!r})'.format(
            self, len(self.minimization_results)
        )

    def add_result(self, res):
        """
        Add a minimization result to the container.

        The position of the result is normalized in place. Results which
        are unsuccessful or whose value exceeds ``gap_threshold`` are
        stored in ``rejected_results``, all others are added to ``nodes``.

        Arguments
        ---------
        res : MinimizationResult
            Minimization result to add.

        Returns
        -------
        bool
            ``True`` if the result was added as a node, ``False`` if it
            was rejected.
        """
        res.pos = self.coordinate_system.normalize_position(res.pos)
        is_rejected = not res.success or res.value > self.gap_threshold
        if is_rejected:
            self.rejected_results.append(res)
        else:
            self.nodes.add_point(self.coordinate_system.get_frac(res.pos), res)
        # Accepted and rejected results are both part of
        # ``minimization_results``, so the flag has to be raised on either
        # path -- and before returning, otherwise it is never reached.
        self.needs_saving = True
        return not is_rejected

    def set_refined(self, pos):
        """
        Set a position to be refined.

        Arguments
        ---------
        pos : np.array
            The position from where refinement started.
        """
        self.refined_results.add_point(
            self.coordinate_system.get_frac(pos), pos
        )
        self.needs_saving = True

    @property
    def minimization_results(self):
        """
        list(MinimizationResult): All minimization results, including
        rejected points. The accepted nodes come first, followed by the
        rejected results.
        """
        return self.nodes.values() + self.rejected_results

    def _get_neighbour_iterator(self, pos):
        """
        Returns an iterator over the nodes which the cell list reports as
        neighbours of the given position. Nodes whose position is exactly
        equal to ``pos`` are skipped.
        """
        candidates = self.nodes.get_neighbour_values(
            frac=self.coordinate_system.get_frac(pos)
        )
        return (c for c in candidates if np.any(c.pos != pos))

    def get_neighbour_distance_iterator(self, pos):
        """
        Returns an iterator over the distance to neighbouring nodes from a
        given position. The candidates are the neighbours reported by the
        cell list; no additional distance filtering is applied here.

        Arguments
        ---------
        pos : numpy.ndarray
            Position for which to calculate the distances.
        """
        candidates = self._get_neighbour_iterator(pos)
        return (
            self.coordinate_system.distance(pos, c.pos) for c in candidates
        )

    def get_refined_neighbour_distance_iterator(self, pos):  # pylint: disable=invalid-name
        """
        Returns an iterator over the distance to neighboring nodes which
        have been used as a starting point in a refinement procedure.

        Arguments
        ---------
        pos : numpy.ndarray
            Position for which to calculate the distances.
        """
        candidates = self.refined_results.get_neighbour_values(
            frac=self.coordinate_system.get_frac(pos)
        )
        return (self.coordinate_system.distance(pos, c) for c in candidates)

    def get_all_neighbour_distances(self, pos):
        """
        Calculate the distances to neighbouring nodes from a given
        position. The candidates are the neighbours reported by the cell
        list; no additional distance filtering is applied here.

        Arguments
        ---------
        pos : numpy.ndarray
            Position for which to calculate the distances.

        Returns
        -------
        The result of a single ``coordinate_system.distance`` call on all
        neighbour positions at once, or an empty list if there are no
        neighbours.
        """
        candidates = self._get_neighbour_iterator(pos)
        positions = np.array([c.pos for c in candidates])
        if positions.size == 0:
            return []
        return self.coordinate_system.distance(pos, positions)