OMatG
OMatG copied to clipboard
State-of-the-art generative model for crystal structure prediction and de novo generation of inorganic crystals.
Open Materials Generation (OMatG)
A state-of-the-art generative model for crystal structure prediction and de novo generation of inorganic crystals. This open-source framework accompanies the ICML 2025 paper (also available on arXiv) which should be cited when using it.
Crystal structure prediction of GaTe:
De novo generation of Pd3Te2I3:
Table of Contents
- Overview.
- Installation.
- Included Datasets.
- Training.
- Generation.
- Visualization.
- Crystal Structure Prediction Metrics.
- De Novo Generation Metrics.
- Citing OMatG.
Overview
OMatG supports two crystal generation modes:
- Crystal structure prediction (CSP): Here, the atomic species are fixed during generation and only the fractional coordinates and lattice vectors are adapted to yield a stable crystal structure for the given composition.
- De novo generation (DNG): At the start of the generation, all atomic species are masked or randomly sampled. During the generation process, the species change together with the lattice vectors and fractional coordinates to yield a stable crystal structure.
OMatG leverages the command line interface of PyTorch Lightning for choosing the crystal generation mode, the interpolants, the dataset, and other hyperparameters. Typically, we recommend to use YAML files to store configurations (and sparsely use individual command line arguments to override some of the configuration parameters). This allows for easy reproducibility and sharing of configurations.
The omg/conf_examples directory contains some example configuration files. In
addition, we provide pretrained checkpoints of the models presented in the paper together with their configuration files
on Hugging Face.
A tutorial notebook for using OMatG including short exercises is available on Kaggle (solutions can be found here). Note that this notebook is part of a more general workshop on generative modeling and thus refers to generative models for fashion pieces. The relevant beginner-friendly notebook that introduces generative modeling with short coding exercises is also available on Kaggle (with solutions here).
Expand this section for a brief introduction to the theoretical background of OMatG.
Theoretical Background
OMatG implements the stochastic interpolants (SIs) framework for the modeling and generation of inorganic crystalline materials. SIs are a unifying framework for generative modeling that encompasses flow-matching and diffusion-based methods as specific instances, while offering a more general and flexible approach enabling the design of a broad class of novel generative models.
A stochastic interpolant $x_t = x(t, x_0, x_1, z)=\alpha(t),x_0 + \beta(t),x_1 + \gamma(t),z$ bridges samples $x_0$ from a (trivial) base distribution to samples $x_1$ from the target data distribution. Here, $t\in[0, 1]$ represents time and the random variable $z$ is drawn from a Gaussian distribution. The functional forms of $\alpha$, $\beta$, and $\gamma$ are flexible, only subject to a few constraints that, amongst other things, ensure that $x(t=0, x_0, x_1, z) = x_0$ and $x(t=1, x_0, x_1, z) = x_1$.
The time-dependent probability density of the stochastic process $x_t$ can be realized either via deterministic sampling through an ordinary differential equation (ODE) or stochastic sampling through a stochastic differential equation (SDE), only requiring a sample $x_0$ from the base distribution. This enables generative modeling by evolving samples from the base distribution to samples from the data distribution. Here, the required velocity term $b^\theta(t, x)$ for both ODE- and SDE-based sampling can be learned from data by "averaging over many paired samples ($x_0, x_1$)." For SDE-based sampling, an additional denoiser $z^\theta(t, x)$ can be learned likewise.
The flexibility of the SI framework stems from the ability to tailor the choice of interpolants and choosing between deterministic (ODE) and stochastic (SDE) sampling schemes (see Fig. 1 that visualizes the tunable components of the SI framework for bridging samples from a base distribution (gray particles) to samples from a target distribution (purple particles); figure taken from the OMatG paper.).
OMatG defines a crystalline material of $N$ atoms by its unit cell that is described by three lattice vectors $\mathbf{L} \in \mathbb{R}^{3\times3}$, $N$ fractional coordinates $\mathbf{X}\in[0,1)^{3\times N}$ with periodic boundary conditions, and $N$ discrete atomic species $\mathbf{A}\in\mathbb{Z}^N_{>0}$. During training an generation, all three components ${\mathbf{A}, \mathbf{X}, \mathbf{L}}$ are considered simultaneously. The SI framework is applied to the continuous structural representations ${\mathbf{X}, \mathbf{L}}$ while the discrete atomic species $\mathbf{A}$ are treated with discrete flow matching.
Expand this section for tips on how to set up new configuration files.
Configuration Files
Machine-learning models implemented with PyTorch Lightning rely on three essential parts:
Trainer: The training engine.LightningDataModule: Handles data loading and preprocessing.LightningModule: Defines the model and training logic.
Configuration files of OMatG thus generally contain specifications for these three parts.
Trainer
OMatG uses the standard PyTorch Lightning Trainer. Its
parameters are specified in the trainer section of the configuration file, for example:
trainer:
callbacks: # List of callbacks to be used during training.
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
filename: "best_val_loss_total"
save_top_k: 1
monitor: "val_loss_total"
save_weights_only: true
accelerator: "gpu"
gradient_clip_val: 0.5
gradient_clip_algorithm: "value"
num_sanity_val_steps: 0
precision: "32-true"
max_epochs: 2000
enable_progress_bar: true
Note that it is possible to initialize specialized classes in the configuration file by specifying the class_path and
init_args. The init_args dictionary contains the arguments that are passed to the constructor of the class.
In addition to the trainer, one should specify the optimizer and (optionally) the learning rate scheduler in their own sections:
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.001
weight_decay: 0.01
lr_scheduler:
class_path: torch.optim.lr_scheduler.CosineAnnealingLR
init_args:
T_max: 2000
eta_min: 1e-07
LightingDataModule
The data section of the configuration constructs the OMGDataModule (see
omg/datamodule/dataloader.py). It mainly expects the
train_dataset, val_dataset, and predict_dataset sections. Each of these sections should construct an
OMGTorchDataset (see omg/datamodule/dataloader.py again). This can be done based
on lmdb files:
data:
train_dataset:
class_path: omg.datamodule.dataloader.OMGTorchDataset
init_args:
dataset:
class_path: omg.datamodule.datamodule.DataModule
init_args:
lmdb_paths:
- "data/mp_20/train.lmdb"
niggli: False
val_dataset:
class_path: omg.datamodule.dataloader.OMGTorchDataset
init_args:
dataset:
class_path: omg.datamodule.datamodule.DataModule
init_args:
lmdb_paths:
- "data/mp_20/val.lmdb"
niggli: False
predict_dataset:
class_path: omg.datamodule.dataloader.OMGTorchDataset
init_args:
dataset:
class_path: omg.datamodule.datamodule.DataModule
init_args:
lmdb_paths:
- "data/mp_20/test.lmdb"
niggli: False
batch_size: 32
num_workers: 4
pin_memory: True
persistent_workers: True
Every record in the lmdb files should contain a crystal structure. The key of each record is assumed to be an (arbitrary) encoded string, while the value is assumed to be a pickled dictionary with, at least, the following keys:
pos: Atorch.Tensorof shape(N, 3)containing the fractional coordinates of the atoms in the crystal structure.cell: Atorch.Tensorof shape(3, 3)containing the lattice vectors of the crystal structure.atomic_numbers: Atorch.Tensorof shape(N,)containing the atomic numbers of the atoms in the crystal structure.
The data section can also contain additional parameters for the data loading (such as batch_size, num_workers,
pin_memory, and persistent_workers in the above example). These parameters are passed to the underlying
PyTorch DataLoader instances.
Within OMatG, the data is passed around as torch_geometric.data.Data instances. For a batch size of batch_size,
these instances contain the following attributes:
n_atoms:torch.Tensorof shape(batch_size, )containing the number of atoms in each configuration.batch:torch.Tensorof shape(sum(n_atoms),)containing the index of the configuration to which each atom belongs.species:torch.Tensorof shape(sum(n_atoms),)containing the atomic numbers of the atoms in the configurations.pos:torch.Tensorof shape(sum(n_atoms), 3)containing the atomic positions of the atoms in the configurations.cell:torch.Tensorof shape(batch_size, 3, 3)containing the cell vectors of the configurations.ptr:torch.Tensorof shape(batch_size + 1,)containing the indices of the first atom of each configuration in thespeciesandpostensors.property: dict containing the properties of the configurations.
LightningModule
The model section of the configuration file constructs the OMGLightningModule (see
omg/omg_lightning.py). Its arguments are documented in the class docstring.
An example model section looks as follows:
model:
si: # Collection of stochastic interpolants.
class_path: omg.si.stochastic_interpolants.StochasticInterpolants
init_args:
stochastic_interpolants:
# Chemical species.
# The SingleStochasticInterpolantIdentity keeps the species unchanged during interpolation (CSP task).
# For DNG, use, e.g., omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask.
- class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
# Fractional coordinates.
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
init_args:
# Use a periodic interpolant for fractional coordinates.
interpolant: omg.si.interpolants.PeriodicLinearInterpolant
gamma: null
epsilon: null
differential_equation_type: "ODE"
integrator_kwargs:
method: "euler"
velocity_annealing_factor: 10.182659004291072
correct_center_of_mass_motion: true
# Lattice vectors.
- class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
init_args:
# Use a non-periodic interpolant for lattice vectors.
interpolant: omg.si.interpolants.LinearInterpolant
gamma: null
epsilon: null
differential_equation_type: "ODE"
integrator_kwargs:
method: "euler"
velocity_annealing_factor: 1.824475401606087
correct_center_of_mass_motion: false
data_fields:
# If the order of the data_fields changes,
# the order of the above StochasticInterpolant inputs must also change.
- "species"
- "pos"
- "cell"
integration_time_steps: 1000
relative_si_costs:
species_loss: 0.0
pos_loss_b: 0.999
cell_loss_b: 0.001
sampler:
class_path: omg.sampler.sample_from_rng.SampleFromRNG
init_args:
# Uniform distribution for fractional coordinates.
pos_distribution: null
cell_distribution:
class_path: omg.sampler.distributions.InformedLatticeDistribution
init_args:
dataset_name: mp_20
species_distribution:
# For DNG, use omg.sampler.distributions.MaskDistribution.
class_path: omg.sampler.distributions.MirrorData
model:
class_path: omg.model.model.Model
init_args:
encoder:
class_path: omg.model.encoders.cspnet_full.CSPNetFull
head:
class_path: omg.model.heads.pass_through.PassThrough
time_embedder:
class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
init_args:
dim: 256
The si section combines the stochastic interpolants for the species, pos, and cell data fields of the crystal
structures in the StochasticInterpolants class.
This class is documented in its docstring but, in a nutshell, it is a container for multiple
StochasticInterpolant instances. The typically used implementations of this abstract class are:
SingleStochasticInterpolant: For continuous data fields such as fractional coordinates and lattice vectors with arbitrary base distributions. The specific interpolant and its parameters are specified on initialization of this class. Every interpolant has a periodic (for fractional coordinates) and a non-periodic (for lattice vectors) version.SingleStochasticInterpolantOS: For continuous data fields such as fractional coordinates and lattice vectors, but explicitly assuming a Gaussian base distribution as it implements one-sided stochastic interpolants.SingleStochasticInterpolantIdentity: For keeping the corresponding data field unchanged during interpolation and generation.DiscreteFlowMatchingMask: For discrete data fields such as atomic species with a completely masked base distribution.DiscreteFlowMatchingUniform: For discrete data fields such as atomic species with a uniform base distribution.
Every StochasticInterpolant in the StochasticInterpolants class computes losses and returns them in a
dictionary (see the loss_keys method in the respective class). The StochasticInterpolants class prefixes these keys
with the name of the corresponding data field so that the losses can be identified. The relative_si_costs section
specifies the relative weights of these losses when they are added up during training.
The sampler section specifies the base distributions for the positions, lattice vectors, and atomic species. Depending
on the choice of the stochastic interpolant, one should choose the matching base distribution:
SingleStochasticInterpolant: The choice of the base distribution is arbitrary. As in the example above, we typically use a uniform distribution for the fractional coordinates and an informed base distribution for the lattice vectors.SingleStochasticInterpolantOS: Explicitly assumes aNormalDistribution.SingleStochasticInterpolantIdentity: Explicitly assumes that the training data is just taken over in the "random" sample as implemented by theMirrorDatadistribution.DiscreteFlowMatchingMask: Explicitly assumes fully masked samples as the base distribution as implemented in theMaskDistribution.DiscreteFlowMatchingUniform: Explicitly assumes uniformly distributed atomic species as the base distribution which can achieved by usingspecies_distribution: null.
The model section specifies the model architecture. In the above example, we just use DiffCSPNet.
Installation
Install the dependencies (see pyproject.toml) and the omg package itself by running
pip install . within the base directory of this repository. For editable mode (recommended for developers), use
pip install -e . instead. You can use any Python version between 3.10 and 3.13.
If the code in this repository changes, the command pip install . has to be executed again to also
change the code of the installed package. If you installed omg in editable mode, any changes in code are directly
available in the installed omg package.
NOTE: Installing PyTorch 2.7 based on the correct compute platform as described on the PyTorch webpage before installing
omgcan help minimize sources of installation errors. The same applies to PyG 2.6 and PyTorch Scatter 2.1.
Installing the omg package as described above provides the omg command for training, generation, and evaluation.
Included Datasets
For convenience, we include several material datasets that can be used for training. They can be found in the
omg/data directory and are described briefly below:
- MP-20: 45,229 structures from the Materials Project with a maximum of 20 atoms per structure.
- MPTS-52: Chronological data split of the Materials Project with 40,476 structures and up to 52 atoms per structure.
- Perov-5: A perovskite dataset containing 18,928 structures each with five atoms per structure.
- Carbon-24: A dataset of 10,153 structures consisting only of carbon with up to 24 atoms per structure.
- Alex-MP-20: New split of a consolidated dataset of 675,204 structures of Alexandria and MP-20 structures. In comparison to MatterGen's dataset, we removed 10% of the training data to create a test dataset. The Alex-MP-20 dataset is too large to be stored in this repository. We have made it available via the HuggingFace link associated with this project.
Training
Run the following command to train OMatG from scratch based on a configuration file:
omg fit --config=<configuration_file.yaml>
This command will create checkpoints, log files, and cache files in the current working directory. The configuration
file contain paths to lmdb files that are used, e.g., for training. The path
to these data files can either be relative to the working directory, or relative to the omg directory within this
repository (that is, use "data/mp_20/val.lmdb" in order to use the included mp_20 data set).
If you want to include a Wandb logger with a name, add the --trainer.logger=WandbLogger --trainer.logger.name=<name>
argument. Other loggers can be found here.
In order to restart training from a checkpoint, add the --ckpt_path=<checkpoint_file.ckpt> argument.
In order to seed the random number generators before training, use --seed_everything=<seed>.
Generation
For generating new structures in an xyz file based on a trained model, run the following command:
omg predict --config=<configuration_file.yaml> --ckpt_path=<checkpoint_file.ckpt> --model.generation_xyz_filename=<xyz_file>
This command will generate one epoch of structures, that is, the number of generated structures is equal to the number structures in the prediction dataset specified in the configuration file.
For an xyz filename filename.xyz, this command will also create a file filename_init.xyz that contains the initial
structures that were integrated to yield the structures in filename.xyz. This file is required for the visualization
below.
If you want to change the batch size of the generation, you can overwrite the batch size in the configuration file with
the --data.batch_size=<new_batch_size> argument.
Crystal-Structure Prediction of Specific Compositions
In order to predict crystal structures for specific compositions, the following command can be used to create an lmdb
file containing only dummy structures with the desired compositions:
omg create_compositions --config=<configuration_file.yaml> --compositions=<compositions> --lmdb_file=<lmdb_file>
Here, <compositions> is a composition string that can be understood by
PyMatgen's Composition class (e.g., --compositions='LiMn3O4') or a list
thereof (e.g., --compositions='[LiMn3O4, Ga4Te4]'). The optional repeats command line argument can be used
to repeat each composition multiple times in the created lmdb file (e.g., for generating multiple structures per
composition). By default, each composition is only included once.
The name of the created lmdb file is specified by <lmdb_file>. This lmdb file can then be used as the prediction
dataset in the configuration file for predicting structures with the desired compositions with the omg predict command
above (it is also possible to overwrite the prediction dataset on the command line with the
--data.predict_dataset.init_args.dataset.init_args.lmdb_paths=[<lmdb_file>] argument).
Importantly, one should use a checkpoint of a crystal-structure-prediction model whose training set includes the
elements of the desired compositions in the omg predict command.
Visualization
Run the following command to compare distributions over the generated structures to distributions over the training dataset:
omg visualize --config=<configuration_file.yaml> --xyz_file=<xyz_file> --plot_name=<plot_name.pdf>
Crystal Structure Prediction Metrics
Run the following command to compute the metrics for the CSP task:
omg csp_metrics --config=<configuration_file.yaml> --xyz_file=<xyz_file>
This command attempts to match structures at the same index in the generated dataset and the prediction dataset.
The metrics include the match rate between the generated structures and the structures in the prediction dataset, as
well as the average (normalized) root-mean square displacement between the matched structures. By default, these metrics
are stored in the csp_metrics.json file. This command also plots the histogram of the root-mean-square distances
between the matched structures in the rmsds.pdf file.
By default, this method first validates the generated structures and the structures in the prediction dataset
based on volume, structure, composition, and fingerprint checks (see ValidAtoms class),
and calculates the match rate between the valid generated structures and the valid structures in the prediction dataset.
The (slow) validation can be skipped by using skip_validation=True.
The validations and matchings are parallelized. The number of processes is determined by os.cpu_count(). This can
be changed by setting the --number_cpus argument (which is probably most useful in cluster environments).
Further arguments are documented in the csp_metrics method in the OMGTrainer class.
De Novo Generation Metrics
Run the following command to compute the metrics for the de novo generation task:
omg dng_metrics --config=<configuration_file.yaml> --xyz_file=<xyz_file> --dataset_name=<dataset_name>
The metrics include validity (structural and compositional) and Wasserstein distances between distributions of density,
volume fraction, number of atoms, number of unique elements, and average coordination number.
In addition, if dataset_name is set to mp_20, carbon_24, or perov_5, the metrics include coverage recall and
precision. By default, these metrics are stored in the dng_metrics.json file.
The validations are parallelized. The number of processes is determined by os.cpu_count(). This can
be changed by setting the --number_cpus argument (which is probably most useful in cluster environments).
Stability related metrics can be computed, for example, with the MatterGen codebase.
Citing OMatG
Please cite the following paper when using OMatG in your work:
@inproceedings{
hoellmer2025,
title={Open Materials Generation with Stochastic Interpolants},
author={Philipp H{\"o}llmer and Thomas Egg and Maya Martirossyan and Eric
Fuemmeler and Zeren Shui and Amit Gupta and Pawan Prakash and Adrian
Roitberg and Mingjie Liu and George Karypis and Mark Transtrum and Richard
Hennig and Ellad B. Tadmor and Stefano Martiniani},
booktitle={Forty-second International Conference on Machine Learning},
year={2025},
url={https://openreview.net/forum?id=gHGrzxFujU},
archivePrefix={arXiv},
eprint={2502.02582},
primaryClass={cs.LG},
}