Source code for assyst.perturbations

"""Classes to apply (random) perturbations to structures."""

import warnings
from abc import ABC, abstractmethod
from ase import Atoms
from typing import Iterable, Callable, Self, Iterator, Union
from dataclasses import dataclass
import numpy as np

from .filters import Filter
from .utils import update_uuid


def rattle(
    structure: Atoms,
    sigma: float,
    rng: Union[int, np.random.Generator, None] = None,
) -> Atoms:
    """Randomly displace positions with gaussian noise.

    Operates INPLACE.

    Args:
        structure (:class:`ase.Atoms`): structure to perturb
        sigma (:class:`float`): standard deviation of the gaussian noise
        rng (:class:`int`, :class:`numpy.random.Generator`): seed or random number generator;
            integer seeds (python or numpy integers) are handed to :meth:`ase.Atoms.rattle` as `seed`,
            anything else as `rng`; if `None` the :mod:`numpy.random` module itself is used as `rng`

    Returns:
        :class:`ase.Atoms`: the same object that was passed in as `structure`

    Raises:
        ValueError: if len(structure) == 1
    """
    if len(structure) == 1:
        raise ValueError("Can only rattle structures larger than one atom.")
    # numpy integer scalars (e.g. when looping over np.arange for seeds) are not
    # instances of int, but must still be treated as seeds rather than generators
    if isinstance(rng, (int, np.integer)):
        structure.rattle(stdev=sigma, seed=int(rng))
    else:
        structure.rattle(stdev=sigma, rng=np.random if rng is None else rng)
    return structure
def element_scaled_rattle(
    structure: Atoms,
    sigma: float,
    reference: dict[str, float],
    rng: Union[int, np.random.Generator, None] = None,
) -> Atoms:
    """Rattle a structure with a per-element standard deviation.

    Works like :func:`.rattle`, except that the noise width of every atom is the relative `sigma` multiplied by the
    `reference` length of that atom's element.

    Operates IN PLACE!

    Args:
        structure (:class:`ase.Atoms`): structure to perturb
        sigma (:class:`float`): relative standard deviation
        reference (:class:`dict` of :class:`str` to :class:`float`): reference length per element
        rng (:class:`int`, :class:`numpy.random.Generator`): seed or random number generator

    Returns:
        :class:`ase.Atoms`: whatever :func:`.rattle` returns for `structure`

    Raises:
        ValueError: if len(structure) == 1, create a super cell first before calling again
        ValueError: if reference values are not positive
        ValueError: if reference does not contain all elements in given structure
    """
    # one entry per atom, each starting out at the relative sigma
    per_atom_sigma = np.ones(len(structure)) * sigma

    # every reference is validated, also those of elements absent from the structure
    if any(not length > 0 for length in reference.values()):
        raise ValueError("Reference lengths must be strictly positive!")

    for index, symbol in enumerate(structure.symbols):
        try:
            length = reference[symbol]
        except KeyError:
            raise ValueError(f"No value for element {symbol} provided in argument `reference`!") from None
        per_atom_sigma[index] *= length

    # column vector, so that each row of positions is paired with its own width
    column = per_atom_sigma.reshape(-1, 1)
    return rattle(structure, column, rng=rng)
def _ensure_perturbation(p: "Perturbation") -> "PerturbationABC":
    """Return `p` itself if it is a :class:`.PerturbationABC`, otherwise wrap it in a :class:`.FunctionPerturbation`."""
    if isinstance(p, PerturbationABC):
        return p
    return FunctionPerturbation(p)
def stretch(
    structure: Atoms,
    hydro: float,
    shear: float,
    minimum_strain=1e-3,
    rng: Union[int, np.random.Generator, None] = None,
) -> Atoms:
    """Randomly stretch cell with uniform noise.

    Strain magnitudes are drawn uniformly between `minimum_strain` and the respective maximum and get a random sign.
    The lower bound avoids structures very close to their original structures.  These don't offer a lot of new
    information and can also confuse VASP's symmetry analyzer.

    Operates INPLACE.

    Args:
        structure (:class:`ase.Atoms`): structure to perturb
        hydro (:class:`float`): maximum hydrostatic (diagonal) strain magnitude; values <= 0 disable it
        shear (:class:`float`): maximum shear (off-diagonal) strain magnitude; values <= 0 disable it
        minimum_strain (:class:`float`): minimum strain magnitude to ensure structures differ from the original;
            if a positive maximum is smaller than this, a :class:`UserWarning` is emitted and the magnitude is
            fixed to that maximum
        rng (:class:`int`, :class:`numpy.random.Generator`): seed or random number generator

    Returns:
        :class:`ase.Atoms`: the same object that was passed in as `structure`
    """
    generator = np.random.default_rng(rng)

    def draw_strains(limit, count):
        # a non-positive limit switches this kind of strain off
        if limit <= 0.0:
            return np.zeros(count)
        if limit < minimum_strain:
            # stacklevel=3 attributes the warning to the caller of stretch()
            warnings.warn(
                f"max_strain ({limit}) is smaller than minimum_strain ({minimum_strain}); "
                "minimum_strain floor cannot be enforced and will be ignored.",
                UserWarning,
                stacklevel=3,
            )
        # keep the interval valid when the floor exceeds the limit
        lower = min(minimum_strain, limit)
        directions = generator.choice([-1, 1], size=count)
        amplitudes = generator.uniform(lower, limit, size=count)
        return directions * amplitudes

    # shear is drawn before hydro, the random stream is consumed in that order
    shear_strains = draw_strains(shear, 3)
    hydro_strains = draw_strains(hydro, 3)

    # diagonal: 1 + hydrostatic strain
    deformation = np.diag(1 + hydro_strains)
    # off-diagonal: the same shear strain above and below the diagonal
    rows, cols = np.triu_indices(3, k=1)
    deformation[rows, cols] = shear_strains
    deformation[cols, rows] = shear_strains

    strained_cell = structure.cell.array @ deformation
    structure.set_cell(strained_cell, scale_atoms=True)
    return structure
class PerturbationABC(ABC):
    """Apply some perturbation to a given structure.

    Sub classes must implement `__str__`; its result is the label that is recorded on perturbed structures.
    Two perturbations can be combined into a :class:`.Series` with `+`.
    """

    def __call__(self, structure: Atoms) -> Atoms:
        """Refresh the uuid of `structure` and record this perturbation's label in `structure.info`.

        Labels of successive perturbations are joined by "+" under the key "perturbation".
        """
        update_uuid(structure)
        label = str(self)
        if "perturbation" in structure.info:
            structure.info["perturbation"] += "+" + label
        else:
            structure.info["perturbation"] = label
        return structure

    @abstractmethod
    def __str__(self) -> str:
        """Return the label of this perturbation."""

    def __add__(self, other: Self) -> "Series":
        """Chain `self` and `other` into a :class:`.Series`."""
        return Series(perturbations=(self, other))
# A perturbation is either a plain function from structure to structure or a PerturbationABC instance.
Perturbation = Union[Callable[[Atoms], Atoms], PerturbationABC]
def perturb(
    structures: Iterable[Atoms],
    perturbations: Iterable[Perturbation],
    filters: Iterable[Filter] | Filter | None = None,
    retries: int = 10,
) -> Iterator[Atoms]:
    """Apply every perturbation to every structure and yield each accepted result separately.

    Each perturbation works on a fresh copy of the structure.  An attempt is repeated, up to `retries` times, until
    the perturbation returns a structure that passes all filters; if none does, nothing is yielded for that
    combination.  A ValueError raised during an attempt ends the attempts for that perturbation and is otherwise
    ignored.

    Args:
        structures: :class:`collections.abc.Iterable` of :class:`ase.Atoms` to perturb.
        perturbations: :class:`collections.abc.Iterable` of :class:`~.Perturbation` that modify structures.
        filters: :class:`collections.abc.Iterable` of :class:`~assyst.filters.Filter` or a single filter to
            select valid results (optional).
        retries: :class:`int`, max attempts per perturbation (default: 10).

    Yields:
        :class:`ase.Atoms`: perturbed structure that passes all filters.
    """
    # normalize `filters` to a list, whatever form it was given in
    if filters is None:
        checks = []
    elif isinstance(filters, Iterable):
        checks = list(filters)
    else:
        checks = [filters]

    # plain functions are wrapped so that everything behaves like a PerturbationABC
    wrapped = [_ensure_perturbation(p) for p in perturbations]

    for original in structures:
        for perturbation in wrapped:
            try:
                for _attempt in range(retries):
                    candidate = perturbation(original.copy())
                    # a perturbation may return None to signal a failed attempt
                    if candidate is not None and all(check(candidate) for check in checks):
                        yield candidate
                        break
            except ValueError:
                # give up on this perturbation for this structure
                continue
class RngMixin:
    """Mixin to handle RNG initialization.

    Expects the class it is mixed into to define an `rng` attribute before `__post_init__` runs.
    """

    def __post_init__(self):
        """Replace `self.rng` with the result of :func:`numpy.random.default_rng` called on it."""
        generator = np.random.default_rng(self.rng)
        # go through object.__setattr__ so that this also works on frozen dataclasses
        object.__setattr__(self, "rng", generator)
@dataclass(frozen=True)
class Rattle(RngMixin, PerturbationABC):
    """Displace atoms by some absolute amount from a normal distribution.

    Class wrapper around :func:`.rattle` that also records the perturbation on the structure.
    """

    sigma: float
    create_supercells: bool = False
    "Create minimal 2x2x2 super cells when applied to structures of only one atom."
    rng: Union[int, np.random.Generator, None] = None

    def __call__(self, structure: Atoms):
        """Rattle `structure`, repeating single atom cells first if `create_supercells` is set."""
        if len(structure) == 1 and self.create_supercells:
            structure = structure.repeat(2)
        # the base class refreshes the uuid and records the label
        tagged = super().__call__(structure)
        return rattle(tagged, self.sigma, rng=self.rng)

    def __str__(self):
        return "rattle({})".format(self.sigma)
@dataclass(frozen=True)
class ElementScaledRattle(RngMixin, PerturbationABC):
    """Displace atoms by some amount from a normal distribution.

    Operates like :class:`.Rattle` but uses a standard deviation derived from the relative `sigma` and the
    `reference`, where this reference is given by element.  Class wrapper around :func:`.element_scaled_rattle`.
    """

    sigma: float
    reference: dict[str, float]
    create_supercells: bool = False
    "Create minimal 2x2x2 super cells when applied to structures of only one atom."
    rng: Union[int, np.random.Generator, None] = None

    def __call__(self, structure: Atoms):
        """Rattle `structure`, repeating single atom cells first if `create_supercells` is set."""
        if len(structure) == 1 and self.create_supercells:
            structure = structure.repeat(2)
        # the base class refreshes the uuid and records the label
        tagged = super().__call__(structure)
        return element_scaled_rattle(tagged, self.sigma, self.reference, rng=self.rng)

    def __str__(self):
        return "scaled_rattle({})".format(self.sigma)
@dataclass(frozen=True)
class Stretch(RngMixin, PerturbationABC):
    """Apply random cell perturbation.

    Class wrapper around :func:`.stretch` that also records the perturbation on the structure.
    """

    hydro: float
    shear: float
    minimum_strain: float = 1e-3
    rng: Union[int, np.random.Generator, None] = None

    def __call__(self, structure: Atoms):
        """Record the perturbation on `structure`, then strain its cell."""
        tagged = super().__call__(structure)
        return stretch(
            tagged,
            hydro=self.hydro,
            shear=self.shear,
            minimum_strain=self.minimum_strain,
            rng=self.rng,
        )

    def __str__(self):
        return "stretch(hydro={}, shear={})".format(self.hydro, self.shear)
@dataclass(frozen=True)
class FunctionPerturbation(PerturbationABC):
    """Wrap a simple function into a PerturbationABC."""

    func: Perturbation

    def __call__(self, structure: Atoms) -> Atoms:
        """Record the perturbation on `structure`, then pass it through the wrapped function."""
        # the base class call updates the uuid before the function runs
        return self.func(super().__call__(structure))

    def __str__(self):
        # the label is simply the string form of the wrapped function
        return str(self.func)
@dataclass(frozen=True)
class Series(PerturbationABC):
    """Apply some perturbations in sequence.

    Each perturbation receives the result of the previous one.
    """

    perturbations: tuple[Perturbation, ...]

    def __post_init__(self):
        # wrap plain functions; object.__setattr__ is needed because the dataclass is frozen
        wrapped = tuple(map(_ensure_perturbation, self.perturbations))
        object.__setattr__(self, "perturbations", wrapped)

    def __call__(self, structure: Atoms) -> Atoms | None:
        """Run `structure` through all perturbations; stops early and returns None once one of them returns None."""
        current = structure
        for step in self.perturbations:
            current = step(current)
            if current is None:
                break
        return current

    def __str__(self):
        return "+".join(map(str, self.perturbations))
@dataclass(frozen=True)
class RandomChoice(RngMixin, PerturbationABC):
    """Apply either of two alternatives randomly.

    On every call one uniform random number is drawn from `rng`; `choice_b` is applied if it falls below `chance`,
    otherwise `choice_a` is applied.  Plain functions given as choices are wrapped on construction.
    """

    choice_a: Perturbation
    choice_b: Perturbation
    chance: float
    "Probability to pick choice b"
    rng: Union[int, np.random.Generator, None] = None

    def __post_init__(self):
        # let RngMixin turn `rng` into a generator first
        super().__post_init__()
        # object.__setattr__ is needed because the dataclass is frozen
        object.__setattr__(self, "choice_a", _ensure_perturbation(self.choice_a))
        object.__setattr__(self, "choice_b", _ensure_perturbation(self.choice_b))

    def __call__(self, structure: Atoms) -> Atoms:
        """Apply one of the two choices to `structure` and return its result."""
        # random() samples the half-open interval [0, 1): with `>=` a chance of 0
        # can never select choice_b and a chance of 1 always selects it
        if self.rng.random() >= self.chance:
            return self.choice_a(structure)
        return self.choice_b(structure)

    def __str__(self):
        return str(self.choice_a) + "|" + str(self.choice_b)
# Names exported by `from <this module> import *`; private helpers are deliberately left out.
__all__ = [ "rattle", "element_scaled_rattle", "stretch", "PerturbationABC", "Perturbation", "perturb", "Rattle", "ElementScaledRattle", "Stretch", "Series", "RandomChoice", ]