Source code for assyst.filters

"""Classes that filter structures according to some criteria.

The code in the other modules that uses them is set up such that simple
functions can always be passed as well and that the classes here are just for
convenience."""

from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, KW_ONLY
from itertools import combinations_with_replacement, product
from math import nan, inf
from numbers import Number
from typing import Callable, Literal

from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from pyxtal.tolerance import Tol_matrix
from ase.data import atomic_numbers
import numpy as np

from assyst.neighbors import neighbor_list


[docs] class FilterBase(ABC): """Base class for filter objects that implements conjunction and disjunction operators.""" def __and__(self, other) -> "AndFilter": return AndFilter(self, other) def __or__(self, other) -> "OrFilter": return OrFilter(self, other) @abstractmethod def __call__(self, structure: Atoms) -> bool: """Returns True if structure passes the filter, False if it should be dropped.""" pass
Filter = Callable[[Atoms], bool] | FilterBase
[docs] @dataclass(frozen=True, eq=True) class AndFilter(FilterBase): """Conjunction of two filters.""" l: Filter r: Filter def __call__(self, structure: Atoms) -> bool: return self.l(structure) and self.r(structure)
[docs] @dataclass(frozen=True, eq=True) class OrFilter(FilterBase): """Disjunction of two filters.""" l: Filter r: Filter def __call__(self, structure: Atoms) -> bool: return self.l(structure) or self.r(structure)
[docs] @dataclass class DistanceFilter(FilterBase): """Filter structures that contain too close atoms. Setting a radius to NaN allows all bonds involving this atom.""" radii: dict[str, float] def __post_init__(self): if isinstance(self.radii, Number): r = self.radii self.radii = defaultdict(lambda: r) @staticmethod def _element_wise_dist(structure: Atoms) -> dict[tuple[str, str], float]: pair: dict[tuple[str, str], float] = defaultdict(lambda: inf) for i, j, d in zip(*neighbor_list("ijd", structure, 5.0)): ei, ej = sorted((structure.symbols[i], structure.symbols[j])) pair[ei, ej] = min(d, pair[ei, ej]) return pair def __call__(self, structure: Atoms) -> bool: """ Return True if structure satisfies minimum distance criteria. Args: structure (:class:`ase.Atoms`): structure to check Returns: `False`: at least one bond is shorter than the sum of given cutoff radii of the respective elements `True`: all bonds are longer than the sum of given cutoff radii of the respective elements """ pair = self._element_wise_dist(structure) for ei, ej in combinations_with_replacement(structure.symbols.species(), 2): ei, ej = sorted((ei, ej)) if pair[ei, ej] < self.radii.get(ei, nan) + self.radii.get(ej, nan): return False return True
[docs] def to_tol_matrix( self, prototype: Literal["metallic", "atomic", "molecular", "vdW"] = "metallic" ) -> Tol_matrix: """Returns equivalent tolerance matrix for pyxtal. Args: prototype (metallic, atomic, molecular or vdW): passed to Tol_matrix as is and used there to initialize radii of elements not explicitly set in this filter """ return Tol_matrix( *( ( atomic_numbers[e1], atomic_numbers[e2], self.radii[e1] + self.radii[e2], ) for e1, e2 in product(self.radii, repeat=2) ), prototype=prototype, )
[docs] @dataclass class AspectFilter(FilterBase): """Filters structures with high aspect ratios.""" maximum_aspect_ratio: float = 6 def __call__(self, structure: Atoms) -> bool: """Return True if structure's cell has an agreeable aspect ratio. Args: structure (:class:`ase.Atoms`): structure to check Returns: `True`: lattice's aspect ratio is below or equal `:attr:`.maximum_aspect_ratio`. `False`: lattice's aspect ratio is above `:attr:`.maximum_aspect_ratio`.""" a, b, c = sorted(structure.cell.lengths()) return c / a <= self.maximum_aspect_ratio
[docs] @dataclass(init=False) class VolumeFilter(FilterBase): """Filters structures by volume per atom (ų/atom). Keeps structures whose volume per atom falls within ``[minimum_volume_per_atom, maximum_volume_per_atom]``. Constructor follows :func:`range`-style semantics: a single positional argument is the upper bound, two positional arguments are ``(minimum, maximum)``. """ minimum_volume_per_atom: float = 0.0 """Lower bound on volume per atom in ų/atom (default: 0).""" maximum_volume_per_atom: float = +inf """Upper bound on volume per atom in ų/atom (default: +∞).""" def __init__( self, *args: float, minimum_volume_per_atom: float = 0.0, maximum_volume_per_atom: float = +inf, ): match args: case (): pass case (max_,): maximum_volume_per_atom = max_ case (min_, max_): minimum_volume_per_atom = min_ maximum_volume_per_atom = max_ case _: raise TypeError( f"VolumeFilter takes at most 2 positional arguments ({len(args)} given)" ) self.minimum_volume_per_atom = minimum_volume_per_atom self.maximum_volume_per_atom = maximum_volume_per_atom def __call__(self, structure: Atoms) -> bool: """Return True if structure's volume is within range. Args: structure (:class:`ase.Atoms`): structure to check Returns: `True`: volume per atom is within ``[:attr:`.minimum_volume_per_atom`, :attr:`.maximum_volume_per_atom`]``. `False`: otherwise""" volume_per_atom = structure.cell.volume / len(structure) return ( self.minimum_volume_per_atom <= volume_per_atom <= self.maximum_volume_per_atom )
[docs] @dataclass class CalculatorFilter(FilterBase): """Filters that require a single point calculator set on the structure.""" _: KW_ONLY missing: Literal["error", "ignore"] = "error" """Behaviour when a structure has no :class:`~ase.calculators.singlepoint.SinglePointCalculator` attached. ``"error"`` (default): raise :exc:`ValueError`. ``"ignore"``: silently pass the structure through (return ``True``). """ def _check(self, structure: Atoms) -> bool: match self.missing: case "error": if structure.calc is None: raise ValueError("Structure must have single point calculator set!") if not isinstance(structure.calc, SinglePointCalculator): raise ValueError( f"Structure must have single point calculator set, not {type(structure.calc)}!" ) return True case "ignore": return False case _: assert False
[docs] @dataclass class EnergyFilter(CalculatorFilter): """Filters structures by energy per atom (eV/atom). Keeps structures whose energy per atom falls within ``[min_energy, max_energy]``. """ min_energy: float = -inf """Lower bound on energy per atom in eV/atom (default: −∞).""" max_energy: float = +inf """Upper bound on energy per atom in eV/atom (default: +∞).""" def __call__(self, structure: Atoms) -> bool: if not self._check(structure): return True return ( self.min_energy <= structure.get_potential_energy() / len(structure) <= self.max_energy )
[docs] @dataclass class ForceFilter(CalculatorFilter): """Filters structures by maximum atomic force magnitude (eV/Å). Keeps structures where no atom experiences a force larger than ``max_force``. """ max_force: float = +inf """Maximum allowed force magnitude in eV/Å (default: +∞).""" def __call__(self, structure: Atoms) -> bool: if not self._check(structure): return True return np.linalg.norm(structure.get_forces(), axis=-1).max() <= self.max_force
__all__ = [ "FilterBase", "Filter", "AndFilter", "OrFilter", "DistanceFilter", "AspectFilter", "VolumeFilter", "CalculatorFilter", "EnergyFilter", "ForceFilter" ]