"""Classes that filter structures according to some criteria.

Code in the other modules that accepts filters is written so that plain functions
(taking an :class:`ase.Atoms` and returning a bool) can always be passed instead;
the classes here exist only for convenience."""
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, KW_ONLY
from itertools import combinations_with_replacement, product
from math import nan, inf
from numbers import Number
from typing import Callable, Literal
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from pyxtal.tolerance import Tol_matrix
from ase.data import atomic_numbers
import numpy as np
from assyst.neighbors import neighbor_list
[docs]
class FilterBase(ABC):
    """Base class for filter objects that implements conjunction and disjunction operators.

    Instances combine with ``&`` (both must accept) and ``|`` (either must accept).
    A plain callable may be used on either side of the operator as long as the other
    operand is a :class:`FilterBase`; operand order is preserved in the result.
    """

    def __and__(self, other) -> "AndFilter":
        """Return the conjunction ``self & other``."""
        return AndFilter(self, other)

    def __rand__(self, other) -> "AndFilter":
        """Return the conjunction ``other & self`` (e.g. when ``other`` is a plain function)."""
        return AndFilter(other, self)

    def __or__(self, other) -> "OrFilter":
        """Return the disjunction ``self | other``."""
        return OrFilter(self, other)

    def __ror__(self, other) -> "OrFilter":
        """Return the disjunction ``other | self`` (e.g. when ``other`` is a plain function)."""
        return OrFilter(other, self)

    @abstractmethod
    def __call__(self, structure: Atoms) -> bool:
        """Returns True if structure passes the filter, False if it should be dropped."""
        pass
# Anything usable as a filter: either a plain callable that takes an ``Atoms`` and
# returns True to keep it (False to drop it), or an instance of a FilterBase subclass.
Filter = Callable[[Atoms], bool] | FilterBase
[docs]
@dataclass(frozen=True, eq=True)
class AndFilter(FilterBase):
    """Conjunction of two filters.

    ``l`` is evaluated first; ``r`` is only evaluated when ``l`` accepts the structure.
    """

    l: Filter
    r: Filter

    def __call__(self, structure: Atoms) -> bool:
        left = self.l(structure)
        # Short-circuit exactly like ``left and right``.
        return self.r(structure) if left else left
[docs]
@dataclass(frozen=True, eq=True)
class OrFilter(FilterBase):
    """Disjunction of two filters.

    ``l`` is evaluated first; ``r`` is only evaluated when ``l`` rejects the structure.
    """

    l: Filter
    r: Filter

    def __call__(self, structure: Atoms) -> bool:
        left = self.l(structure)
        # Short-circuit exactly like ``left or right``.
        return left if left else self.r(structure)
[docs]
@dataclass
class DistanceFilter(FilterBase):
    """Filter structures that contain too close atoms.

    Two atoms of elements ``A`` and ``B`` are too close when their distance is smaller
    than ``radii[A] + radii[B]``. ``radii`` may also be a single number, which is then
    used for every element.

    Setting a radius to NaN allows all bonds involving this atom; elements that have no
    entry in ``radii`` are treated the same way.
    """

    radii: dict[str, float]

    # Neighbor-search cutoff (Å) used unless a larger radius sum requires more.
    _DEFAULT_CUTOFF = 5.0

    def __post_init__(self):
        # A scalar radius applies to every element.
        if isinstance(self.radii, Number):
            r = self.radii
            self.radii = defaultdict(lambda: r)

    def _radius(self, element: str) -> float:
        """Return the radius configured for ``element``, or NaN if there is none.

        ``dict.get`` ignores a defaultdict's factory, so the scalar-radius case is
        resolved explicitly. The factory is called without inserting into ``radii``,
        so :meth:`to_tol_matrix` does not depend on earlier filter calls.
        """
        if element in self.radii:
            return self.radii[element]
        factory = getattr(self.radii, "default_factory", None)
        return nan if factory is None else factory()

    @staticmethod
    def _element_wise_dist(
        structure: Atoms, cutoff: float = _DEFAULT_CUTOFF
    ) -> dict[tuple[str, str], float]:
        """Return the shortest distance per alphabetically sorted element pair.

        Only neighbors found within ``cutoff`` are considered; pairs without any such
        neighbor map to ``inf``.
        """
        symbols = structure.get_chemical_symbols()
        pair: dict[tuple[str, str], float] = defaultdict(lambda: inf)
        for i, j, d in zip(*neighbor_list("ijd", structure, cutoff)):
            ei, ej = sorted((symbols[i], symbols[j]))
            pair[ei, ej] = min(d, pair[ei, ej])
        return pair

    def __call__(self, structure: Atoms) -> bool:
        """
        Return True if structure satisfies minimum distance criteria.

        Args:
            structure (:class:`ase.Atoms`): structure to check

        Returns:
            `False`: at least one bond is shorter than the sum of given cutoff radii of the respective elements
            `True`: all bonds are longer than the sum of given cutoff radii of the respective elements
        """
        # Minimum allowed distance per element pair present in the structure. A NaN sum
        # can never be violated, so such pairs need no check.
        thresholds: dict[tuple[str, str], float] = {}
        species = sorted(structure.symbols.species())
        for ei, ej in combinations_with_replacement(species, 2):
            limit = self._radius(ei) + self._radius(ej)
            if not np.isnan(limit):
                thresholds[ei, ej] = limit
        if not thresholds:
            return True
        # The neighbor search has to reach the largest finite threshold; otherwise
        # too-short pairs beyond the default cutoff would go unnoticed.
        finite = [limit for limit in thresholds.values() if np.isfinite(limit)]
        cutoff = max([self._DEFAULT_CUTOFF, *finite])
        pair = self._element_wise_dist(structure, cutoff)
        return all(pair[key] >= limit for key, limit in thresholds.items())

    def to_tol_matrix(
        self, prototype: Literal["metallic", "atomic", "molecular", "vdW"] = "metallic"
    ) -> Tol_matrix:
        """Returns equivalent tolerance matrix for pyxtal.

        Only elements that are keys of ``radii`` get explicit entries. With a scalar
        radius ``radii`` starts out empty, so no explicit entries are set in that case.

        Args:
            prototype (metallic, atomic, molecular or vdW):
                passed to Tol_matrix as is and used there to initialize radii of elements not explicitly set in this
                filter
        """
        return Tol_matrix(
            *(
                (
                    atomic_numbers[e1],
                    atomic_numbers[e2],
                    self.radii[e1] + self.radii[e2],
                )
                for e1, e2 in product(self.radii, repeat=2)
            ),
            prototype=prototype,
        )
[docs]
@dataclass
class AspectFilter(FilterBase):
    """Filters structures with high aspect ratios.

    The aspect ratio is the longest cell vector length divided by the shortest one.
    """

    maximum_aspect_ratio: float = 6

    def __call__(self, structure: Atoms) -> bool:
        """Return True if structure's cell has an agreeable aspect ratio.

        Args:
            structure (:class:`ase.Atoms`): structure to check

        Returns:
            `True`: lattice's aspect ratio is below or equal :attr:`maximum_aspect_ratio`.
            `False`: lattice's aspect ratio is above :attr:`maximum_aspect_ratio`.
        """
        lengths = structure.cell.lengths()
        ratio = max(lengths) / min(lengths)
        return ratio <= self.maximum_aspect_ratio
[docs]
@dataclass
class VolumeFilter(FilterBase):
    """Filters structures by volume per atom (Å³/atom).

    Keeps structures whose volume per atom does not exceed ``maximum_volume_per_atom``.
    """

    #: Upper bound on volume per atom in Å³/atom.
    maximum_volume_per_atom: float

    def __call__(self, structure: Atoms) -> bool:
        """Return True if structure's volume is within range.

        Args:
            structure (:class:`ase.Atoms`): structure to check

        Returns:
            `True`: volume per atom is smaller or equal than :attr:`maximum_volume_per_atom`.
            `False`: otherwise
        """
        volume_per_atom = structure.cell.volume / len(structure)
        return volume_per_atom <= self.maximum_volume_per_atom
[docs]
@dataclass
class CalculatorFilter(FilterBase):
"""Filters that require a single point calculator set on the structure."""
_: KW_ONLY
missing: Literal["error", "ignore"] = "error"
"""Behaviour when a structure has no :class:`~ase.calculators.singlepoint.SinglePointCalculator` attached.
``"error"`` (default): raise :exc:`ValueError`.
``"ignore"``: silently pass the structure through (return ``True``).
"""
def _check(self, structure: Atoms) -> bool:
match self.missing:
case "error":
if structure.calc is None:
raise ValueError("Structure must have single point calculator set!")
if not isinstance(structure.calc, SinglePointCalculator):
raise ValueError(
f"Structure must have single point calculator set, not {type(structure.calc)}!"
)
return True
case "ignore":
return False
case _:
assert False
[docs]
@dataclass
class EnergyFilter(CalculatorFilter):
    """Filters structures by energy per atom (eV/atom).

    Keeps structures whose energy per atom falls within ``[min_energy, max_energy]``.
    """

    #: Lower bound on energy per atom in eV/atom (default: −∞).
    min_energy: float = -inf
    #: Upper bound on energy per atom in eV/atom (default: +∞).
    max_energy: float = +inf

    def __call__(self, structure: Atoms) -> bool:
        """Return True if the structure's energy per atom lies within the bounds.

        Structures for which :meth:`_check` returns False are passed through unchanged.
        """
        if self._check(structure):
            energy_per_atom = structure.get_potential_energy() / len(structure)
            return self.min_energy <= energy_per_atom <= self.max_energy
        return True
[docs]
@dataclass
class ForceFilter(CalculatorFilter):
    """Filters structures by maximum atomic force magnitude (eV/Å).

    Keeps structures where no atom experiences a force larger than ``max_force``.
    """

    #: Maximum allowed force magnitude in eV/Å (default: +∞).
    max_force: float = +inf

    def __call__(self, structure: Atoms) -> bool:
        """Return True if the largest per-atom force norm does not exceed ``max_force``.

        Structures for which :meth:`_check` returns False are passed through unchanged.
        """
        if self._check(structure):
            # Euclidean norm of each atom's force vector (norm over the last axis).
            magnitudes = np.linalg.norm(structure.get_forces(), axis=-1)
            return magnitudes.max() <= self.max_force
        return True
# Public names exported by ``from <module> import *``.
__all__ = [
    "FilterBase", "Filter",
    "AndFilter", "OrFilter",
    "DistanceFilter", "AspectFilter", "VolumeFilter",
    "CalculatorFilter", "EnergyFilter", "ForceFilter",
]