"""Helper plotting functions."""
from typing import Literal, Callable, Iterable
from collections import Counter, defaultdict
from ase import Atoms
import matplotlib.pyplot as plt
import numpy as np
from assyst.neighbors import neighbor_list
def _volume(structures: Iterable[Atoms]) -> list[float]:
    """Return the cell volume per atom of every structure.

    Args:
        structures (iterable of :class:`ase.Atoms`):
            structures to evaluate

    Returns:
        list of float: ``cell.volume`` divided by the number of atoms, one
        entry per structure, in input order
    """
    per_atom = []
    for atoms in structures:
        per_atom.append(atoms.cell.volume / len(atoms))
    return per_atom
def _energy(structures: Iterable[Atoms]) -> list[float]:
    """Return the potential energy per atom of every structure.

    Args:
        structures (iterable of :class:`ase.Atoms`):
            structures to evaluate; each must be able to answer
            ``get_potential_energy()``

    Returns:
        list of float: potential energy divided by the number of atoms, one
        entry per structure, in input order
    """
    per_atom = []
    for atoms in structures:
        per_atom.append(atoms.get_potential_energy() / len(atoms))
    return per_atom
def _concentration(
    structures: Iterable[Atoms], elements: Iterable[str] | None = None
) -> dict[str, np.ndarray]:
    """Return the per-structure concentration of every element.

    Args:
        structures (iterable of :class:`ase.Atoms`):
            structures to evaluate
        elements (iterable of str):
            restrict (and order) the result to these elements; elements that
            occur in no structure get an all-zero array.  By default all
            elements present in `structures` are returned.

    Returns:
        dict: maps element symbol to an array with one entry per structure,
        holding the fraction of atoms of that element in the structure
        (zero where the element is absent)
    """
    fractions = [
        {symbol: count / len(s) for symbol, count in Counter(s.symbols).items()}
        for s in structures
    ]
    # Missing elements start out as zeros so that every array has one slot per
    # structure, whether or not the element occurs in it.
    concentrations = defaultdict(lambda: np.zeros(len(fractions)))
    for i, per_structure in enumerate(fractions):
        for symbol, fraction in per_structure.items():
            concentrations[symbol][i] = fraction
    if elements is not None:
        return {e: concentrations[e] for e in elements}
    # Hand out a plain dict in both branches, so that a later lookup of an
    # unknown element fails loudly instead of silently creating an entry.
    return dict(concentrations)
def _distance(
    structures: Iterable[Atoms], rmax: float
) -> list[Iterable[float]]:
    """Collect the neighbor distances of every structure.

    Args:
        structures (iterable of :class:`ase.Atoms`):
            structures to evaluate
        rmax (float):
            cutoff handed to :func:`neighbor_list` (converted with ``float``)

    Returns:
        list: one entry per structure, holding whatever
        ``neighbor_list("d", ...)`` returns for it (presumably the array of
        pair distances within `rmax`)
    """
    per_structure = []
    for atoms in structures:
        per_structure.append(neighbor_list("d", atoms, float(rmax)))
    return per_structure
def _plot_histogram(
    structures: Iterable[Atoms],
    extractor: Callable[[Iterable[Atoms]], Iterable[float]],
    xlabel: str,
    ylabel: str,
    **kwargs
):
    """Histogram a quantity extracted from structures on the current axes.

    Args:
        structures (iterable of :class:`ase.Atoms`):
            structures to plot
        extractor (callable):
            called once with `structures`; its return value is the data that
            gets binned
        xlabel (str):
            label for x-axis
        ylabel (str):
            label for y-axis
        **kwargs:
            passed through to :func:`matplotlib.pyplot.hist`

    Returns:
        Return value of :func:`matplotlib.pyplot.hist`
    """
    hist_result = plt.hist(extractor(structures), **kwargs)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    return hist_result
[docs]
def volume_histogram(structures: list[Atoms], **kwargs):
    """Plot a histogram of the volume per atom of the structures.

    Args:
        structures (list of :class:`ase.Atoms`):
            structures to plot
        **kwargs:
            passed through to :func:`matplotlib.pyplot.hist`

    Returns:
        Return value of :func:`matplotlib.pyplot.hist`
    """
    xlabel = r"Volume [$\mathrm{\AA}^3/\mathrm{atom}$]"
    ylabel = r"#$\,$Structures"
    return _plot_histogram(structures, _volume, xlabel, ylabel, **kwargs)
[docs]
def size_histogram(structures: list[Atoms], **kwargs):
    """Plot a histogram of the number of atoms per structure.

    Args:
        structures (list of :class:`ase.Atoms`):
            structures to plot
        **kwargs:
            passed through to :func:`matplotlib.pyplot.hist`

    Returns:
        Return value of :func:`matplotlib.pyplot.hist`
    """

    def count_atoms(batch):
        # One integer per structure: its number of atoms.
        return [len(atoms) for atoms in batch]

    xlabel = "# Atoms"
    ylabel = r"#$\,$Structures"
    return _plot_histogram(structures, count_atoms, xlabel, ylabel, **kwargs)
[docs]
def concentration_histogram(
    structures: list[Atoms], elements: Iterable[str] | None = None, **kwargs
):
    """Plot histogram of concentrations.

    For every element one bar is drawn per distinct concentration value; its
    height is the number of structures with that concentration.  The bars of
    the different elements are placed side by side inside a slot of total
    width `width` that is centered on the concentration value.

    Args:
        structures (list of :class:`ase.Atoms`):
            structures to plot
        elements (iterable of str):
            which element concentrations to plot, by default all present
        **kwargs:
            passed through to :func:`matplotlib.pyplot.bar`.  `width` is the
            total width shared by the bars of all elements; it defaults to the
            smallest spacing between distinct concentration values, or to 0.1
            when fewer than two distinct values exist.
    """
    conc = _concentration(structures, elements=elements)
    if "width" not in kwargs:
        # Derive the slot width only when the caller did not choose one.
        # np.unique already returns sorted values, so np.diff yields the gaps
        # between neighboring concentration levels.
        if conc:
            levels = np.unique(np.concatenate(list(conc.values())))
        else:
            levels = np.empty(0)
        # With fewer than two levels (e.g. a unary data set where every
        # concentration is 1) there is no gap to measure; fall back to a
        # fixed width instead of failing on the minimum of an empty array.
        kwargs["width"] = np.diff(levels).min() if len(levels) > 1 else 0.1
    width = kwargs["width"]
    # Each element gets an equal share of the slot; max() guards the division
    # when there is nothing to plot.
    kwargs["width"] = width / max(len(conc), 1)
    shifts = np.linspace(0, 1, len(conc), endpoint=False)
    for i, (e, c) in enumerate(conc.items()):
        x, h = np.unique(c, return_counts=True)
        plt.bar(x + shifts[i] * width - width / 2, h, label=e, align="edge", **kwargs)
    if conc:
        plt.legend()
    plt.xlabel("Concentration")
    plt.ylabel("#$\\,$Structures")
[docs]
def distance_histogram(
    structures: list[Atoms],
    rmax: float = 6.0,
    reduce: Literal["min", "mean"] | Callable[[Iterable[float]], float] | None = "min",
    **kwargs,
):
    """Plot histogram of neighbor distances.

    Args:
        structures (list of :class:`ase.Atoms`):
            structures to plot
        rmax (float):
            maximum cutoff to consider neighborhood
        reduce ("min", "mean", callable from array of floats to float, or None):
            applied to the neighbor distances of each structure and should
            return a single scalar, which is then binned; "min" and "mean"
            select :func:`numpy.min` and :func:`numpy.mean`.  If None, the
            distances of all structures are binned together without reduction.
        **kwargs:
            passed through to :func:`matplotlib.pyplot.hist`; `bins` defaults
            to 100

    Returns:
        Return value of :func:`matplotlib.pyplot.hist`

    Raises:
        ValueError: if `reduce` is a string other than "min" or "mean"
    """
    kwargs.setdefault("bins", 100)
    # Preset name -> (reduction, matching x-axis label).
    presets = {
        "min": (np.min, r"Minimum distance [$\mathrm{\AA}$]"),
        "mean": (np.mean, r"Mean distance [$\mathrm{\AA}$]"),
    }
    xlabel = r"Distance [$\mathrm{\AA}$]"
    if reduce is None:

        def extractor(s):
            return np.concatenate(_distance(s, rmax))

        ylabel = r"#$\,$Neighbours"
    else:
        if isinstance(reduce, str):
            # Reject unknown names up front; otherwise the string itself would
            # later be called like a function and fail with an obscure error.
            if reduce not in presets:
                raise ValueError(
                    f"Unknown reduce {reduce!r}; expected one of "
                    f"{sorted(presets)}, a callable or None"
                )
            reduce_func, xlabel = presets[reduce]
        else:
            reduce_func = reduce

        def extractor(s):
            return [reduce_func(d) for d in _distance(s, rmax)]

        ylabel = r"#$\,$Structures"
    return _plot_histogram(structures, extractor, xlabel, ylabel, **kwargs)
[docs]
def radial_distribution(
    structures: list[Atoms],
    rmax: float = 6.0,
    **kwargs
):
    """Plot radial distribution of neighbors in training set.

    Gathers the neighbor distances of all structures and histograms them
    together.  Every distance r enters with weight 1/(4 pi r^2), but because
    the density in each structure can be different, the plot does *not* yield
    something that can be directly compared to a Radial Distribution Function.

    It can be used to locate preferred bonding distances or to check the
    sampling of the radial neighborhood in a training set given suitable data.

    Args:
        structures (list of :class:`ase.Atoms`):
            structures to plot
        rmax (float):
            maximum cutoff to consider neighborhood
        **kwargs:
            passed through to :func:`matplotlib.pyplot.hist`; `bins` defaults
            to 100

    Returns:
        Return value of :func:`matplotlib.pyplot.hist`
    """
    if "bins" not in kwargs:
        kwargs["bins"] = 100
    r = np.concatenate(_distance(structures, rmax))
    shell_weight = 1 / (4 * np.pi * r ** 2)
    hist_result = plt.hist(r, weights=shell_weight, **kwargs)
    plt.xlabel(r"Distance [$\mathrm{\AA}$]")
    plt.ylabel("Radial distribution")
    return hist_result
[docs]
def energy_histogram(
    structures: list[Atoms],
    **kwargs
):
    """Plot a histogram of the energy per atom.

    Requires that :class:`ase.calculators.singlepoint.SinglePointCalculator`
    are attached to the atoms, either from a relaxation or from the final
    training set calculation.

    Args:
        structures (list of :class:`ase.Atoms`):
            structures to plot
        **kwargs:
            passed through to :func:`matplotlib.pyplot.hist`; `bins` defaults
            to 100

    Returns:
        Return value of :func:`matplotlib.pyplot.hist`
    """
    if "bins" not in kwargs:
        kwargs["bins"] = 100
    xlabel = r"Energy [eV/atom]"
    ylabel = r"#$\,$Structures"
    return _plot_histogram(structures, _energy, xlabel, ylabel, **kwargs)
[docs]
def energy_volume(structures: list[Atoms], **kwargs):
    """Plot energy per atom versus volume per atom.

    Fewer than 1000 structures are drawn as a scatter plot, larger sets as a
    hexagonal binning plot.

    Requires that :class:`ase.calculators.singlepoint.SinglePointCalculator`
    are attached to the atoms, either from a relaxation or from the final
    training set calculation.

    Args:
        structures (list of :class:`ase.Atoms`):
            structures to plot
        **kwargs:
            passed through to :func:`matplotlib.pyplot.scatter` or
            :func:`matplotlib.pyplot.hexbin`, depending on the number of
            structures.  For the scatter plot a scalar `markersize` (in
            points, default 5) is accepted as an alternative to `s` and
            ignored if `s` is given; it has no effect on the hexbin plot.
            For the hexbin plot `bins` defaults to "log".
    """
    # Materialize first: the structures are traversed twice below, which
    # would leave the second pass empty for a one-shot iterator.
    structures = list(structures)
    V = _volume(structures)
    E = _energy(structures)
    # Neither scatter nor hexbin understands `markersize`, so it must never
    # be forwarded as is.
    markersize = kwargs.pop("markersize", None)
    if len(structures) < 1000:
        if "s" not in kwargs:
            # scatter's `s` is the marker area in points**2, hence the square.
            kwargs["s"] = (5 if markersize is None else markersize) ** 2
        plt.scatter(V, E, **kwargs)
    else:
        # setdefault lets callers override the binning without a duplicate
        # keyword error.
        kwargs.setdefault("bins", "log")
        plt.hexbin(V, E, **kwargs)
    plt.xlabel(r"Volume [$\mathrm{\AA}^3/\mathrm{atom}$]")
    plt.ylabel(r"Energy [eV/atom]")
# Public API of this module: the plotting entry points.  The underscore
# prefixed helpers above are implementation details and are not exported.
__all__ = [
"volume_histogram",
"size_histogram",
"concentration_histogram",
"distance_histogram",
"radial_distribution",
"energy_histogram",
"energy_volume",
]