"""Crystal structure generation step of ASSYST."""
from dataclasses import dataclass
from collections.abc import Sequence
from itertools import product, islice
from warnings import catch_warnings, warn
from typing import Self, Iterable, Iterator, Literal, overload, Union
import numpy as np
from .filters import DistanceFilter
from .utils import update_uuid
from ase import Atoms
from pyxtal import pyxtal as _pyxtal
from pyxtal.msg import Comp_CompatibilityError
from pyxtal.tolerance import Tol_matrix
from tqdm.auto import tqdm
def _get_real_spacegroup(s):
"""Use pyxtal to sniff spacegroup of generated crystals.
The pyxtal object does not update the space group after a call to from_random, but keeps the requested one around.
Since the generated atoms may actually be of higher symmetry, reinitialize here to make sure what we've got.
"""
p = _pyxtal()
p.from_seed(s.copy())
return p.group.number
[docs]
def pyxtal(
group: Union[int, list[int]],
species: tuple[str],
num_ions: tuple[int],
dim: Literal[0, 1, 2, 3] = 3,
repeat: int = 1,
allow_exceptions: bool = True,
rng: Union[int, np.random.Generator, None] = None,
**kwargs,
) -> Union[Atoms, list[dict]]:
"""
Generate random crystal structures with PyXtal.
`group` must be between 1 and the largest possible value for the given dimensionality:
dim=3 => 1 - 230 (space groups)
dim=2 => 1 - 80 (layer groups)
dim=1 => 1 - 75 (rod groups)
dim=0 => 1 - 58 (point groups)
When `group` is passed as a :class:`list` of :class:`int` or `repeat>1`, generate multiple structures and return them in a :class:`list`
of :class:`dict` containing the keys `atoms`, `symmetry` and `repeat` for the ASE structure, the symmetry group
number and which iteration it is, respectively.
Args:
group (list of int, or int): the symmetry group to generate or a list of them
species (tuple of str): which species to include, defines the stoichiometry together with `num_ions`
num_ions (tuple of int): how many of each species to include, defines the stoichiometry together with `species`
dim (int): dimensionality of the symmetry group, 0 is point groups, 1 is rod groups, 2 is layer groups and 3 is space groups
repeat (:class:`int`): how many random structures to generate
allow_exceptions (:class:`bool`): when generating multiple structures, silence errors when the requested stoichiometry and symmetry group are incompatible
rng (:class:`int`, :class:`numpy.random.Generator`): seed or random number generator
**kwargs: passed to `pyxtal.pyxtal` function verbatim
Returns:
:class:`ase.Atoms`: the generated structure, if repeat==1 and only one symmetry group is requested
:class:`list` of :class:`dict` of all generated structures, if repeat>1 or multiple symmetry groups are requested
Raises:
ValueError: if `species` and `num_ions` are not of the same length
ValueError: if stoichiometry and symmetry group are incompatible and allow_exceptions==False or only one structure is requested
"""
if len(species) != len(num_ions):
raise ValueError(
"species and num_ions must be of same length, "
f"not {species} and {num_ions}!"
)
stoich = "".join(f"{s}{n}" for s, n in zip(species, num_ions))
_rng = np.random.default_rng(rng)
def generate(group):
s = _pyxtal()
try:
s.from_random(
dim=dim,
group=group,
species=species,
numIons=num_ions,
random_state=_rng,
**kwargs,
)
except Comp_CompatibilityError:
if not allow_exceptions:
raise ValueError(
f"Symmetry group {group} incompatible with stoichiometry {stoich}!"
) from None
else:
return None
s = s.to_ase()
update_uuid(s)
s.wrap(center=(0, 0, 0))
return s
# return a single structure
if repeat == 1 and isinstance(group, int):
allow_exceptions = False
return generate(group)
else:
structures = []
if isinstance(group, int):
group = [group]
failed_groups = []
for g in tqdm(group, desc="Spacegroups"):
for i in range(repeat):
s = generate(g)
if s is None:
failed_groups.append(g)
continue
structures.append(
{
"atoms": s,
"symmetry": g,
"repeat": i,
"requested spacegroup": g,
"spacegroup": _get_real_spacegroup(s),
}
)
if len(failed_groups) > 0:
warn(
f"Groups [{', '.join(map(str, failed_groups))}] could not be generated with stoichiometry {stoich}!",
stacklevel=2,
)
return structures
[docs]
def sample(
formulas: Formulas | Iterable[dict[str, int]],
spacegroups: list[int] | tuple[int, ...] | Iterable[int] | None = None,
min_atoms: int = 1,
max_atoms: int = 10,
max_structures: int | None = None,
dim: Literal[0, 1, 2, 3] = 3,
tolerance: (
Literal["metallic", "atomic", "molecular", "vdW"] | DistanceFilter | dict
) = "metallic",
rng: Union[int, np.random.Generator, None] = None,
) -> Iterator[Atoms]:
"""
Create symmetric random structures.
Args:
formulas (:class:`.Formulas` or :class:`collections.abc.Iterable` of :class:`dict` from :class:`str` to :class:`int`): :class:`list` of chemical formulas
spacegroups (list of int): which space groups to generate
min_atoms (int): do not generate structures smaller than this
max_atoms (int): do not generate structures larger than this
max_structures (int): generate at most this many structures
dim (one of 0, 1, 2, or 3): the dimensionality of the structures to generate; if lower than 3 the code generates
samples no longer from space groups, but from the subperiodic layer, rod, or point groups.
tolerance (str, dict of elements to radii):
specifies minimum allowed distances between atoms in generated structures;
if str then it should be one values understood by :class:`pyxtal.tolerance.Tol_matrix`;
if dict each value gives the minimum *radius* allowed for an atom, whether a given distance is allowed then
depends on the sum of the radii of the respective elements
rng (:class:`int`, :class:`numpy.random.Generator`): seed or random number generator
Yields:
:class:`ase.Atoms`: random symmetric crystal structures
"""
if not 0 <= dim <= 3:
raise ValueError(f"dim must be in range [0, 3], not {dim}!")
# number of (sub-)periodic symmetry groups available in 0-3 dimensions
max_group = [58, 75, 80, 230][dim]
if spacegroups is None:
spacegroups = range(1, max_group + 1)
spacegroups = list(spacegroups)
min_spg = min(spacegroups)
max_spg = max(spacegroups)
if min_spg <= 0 or max_group < max_spg:
raise ValueError(
f"spacegroups must be in range [1, {max_group}], not [{min_spg}, {max_spg}] (dim={dim})!"
)
tm: Tol_matrix | None
match tolerance:
case "metallic" | "atomic" | "molecular" | "vdW":
tm = Tol_matrix(prototype=tolerance)
case dict():
tm = (
DistanceFilter(tolerance).to_tol_matrix()
if len(tolerance) > 0
else None
)
case DistanceFilter():
tm = tolerance.to_tol_matrix()
case _:
raise ValueError("invalid value tolerance={tolerance}!")
for stoich in (bar := tqdm(formulas)):
# pyxtal never returns structures when one element with zero atoms is present, so filter here first for
# robustness
stoich = {e: n for e, n in stoich.items() if n > 0}
if len(stoich) == 0:
continue
elements, num_atoms = zip(*stoich.items())
if not min_atoms <= sum(num_atoms) <= max_atoms:
continue
stoich_str = "".join(f"{s}{n}" for s, n in zip(elements, num_atoms))
bar.set_description(stoich_str)
def pop(s):
atoms = s.pop("atoms")
atoms.info.update(s)
return atoms
with catch_warnings(category=UserWarning, action="ignore"):
px = pyxtal(spacegroups, elements, num_atoms, dim=dim, tm=tm, rng=rng)
yield from islice(map(pop, px), max_structures)
if max_structures is not None:
max_structures -= len(px)
if max_structures <= 0:
break
__all__ = [
"pyxtal",
"Formulas",
"sample",
]