metal_kernels
First custom Metal kernel for mlx_atomistic: a fused Lennard-Jones force kernel.
Collapses the per-step pairwise LJ force op-chain (gather -> minimum image -> r^2 ->
LJ scalar -> scatter-add) into a single mx.fast.metal_kernel dispatch. Forces are accumulated
per atom with atomic scatter-adds over a half neighbor list. Each pair's energy is written to
its own output slot (no contention), and the caller sums the slots. This keeps the energy
accurate, which the periodic-virial finite-difference path needs, without the hot spot of
every pair atomically adding into a single energy cell.
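For orientation, here is the unfused op-chain written out step by step. This is a minimal NumPy sketch, not the library's MLX code: lj_op_chain is a hypothetical name, NumPy stands in for mlx.core, and it assumes the standard 12-6 LJ form with a shift that subtracts U(cutoff). np.add.at plays the role of the scatter-add, and the per-pair energy array mirrors the "one slot per pair, summed by the caller" layout.

```python
import numpy as np

def lj_op_chain(positions, pairs, box_lengths, epsilon, sigma, cutoff, shift):
    """Hypothetical NumPy reference for the unfused LJ op-chain.

    positions: (N, 3) floats; pairs: (M, 2) ints, half neighbor list (each pair once);
    box_lengths: (3,) orthorhombic edge lengths. Returns (energy, forces (N, 3)).
    """
    i, j = pairs[:, 0], pairs[:, 1]
    # 1. gather: displacement from atom i to atom j for every listed pair
    d = positions[j] - positions[i]
    # 2. minimum image in an orthorhombic box
    d = d - box_lengths * np.round(d / box_lengths)
    # 3. squared distance and the r^2 cutoff mask
    r2 = np.sum(d * d, axis=1)
    mask = r2 < cutoff * cutoff
    # 4. LJ scalar: per-pair energy (one slot per pair) and force magnitude / r
    sr2 = np.where(mask, sigma * sigma / r2, 0.0)
    sr6 = sr2 ** 3
    e_pair = 4.0 * epsilon * (sr6 * sr6 - sr6)
    if shift:
        sc6 = (sigma / cutoff) ** 6
        e_pair = e_pair - np.where(mask, 4.0 * epsilon * (sc6 * sc6 - sc6), 0.0)
    # -(dU/dr) / r = 24 eps (2 (s/r)^12 - (s/r)^6) / r^2
    f_over_r = np.where(mask, 24.0 * epsilon * (2.0 * sr6 * sr6 - sr6) / r2, 0.0)
    # 5. scatter-add: force on j is f_vec, force on i is -f_vec (Newton's third law)
    f_vec = f_over_r[:, None] * d
    forces = np.zeros_like(positions)
    np.add.at(forces, j, f_vec)
    np.add.at(forces, i, -f_vec)
    # the per-pair energy slots are reduced only at the end
    return e_pair.sum(), forces
```

Every stage materializes an (M, ...) intermediate here; fusing them into one kernel is what removes those round trips.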
Scope: pure-LJ reduced units, scalar epsilon/sigma, orthorhombic cell. Other cases
(Coulomb, triclinic cells, topology exclusions, the biomolecular NonbondedPotential)
stay on the MLX op-chain; callers fall back transparently.
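The transparent fallback amounts to a scope check before dispatch. The sketch below is hypothetical: select_lj_path and its flags are illustrative names, not the module's API, and the check encodes only the conditions stated above (pure LJ, scalar epsilon/sigma, orthorhombic cell, no Coulomb or topology exclusions).

```python
import numpy as np

def select_lj_path(cell_matrix, epsilon, sigma, *, has_coulomb=False,
                   has_exclusions=False, atol=0.0):
    """Return "fused" if the fused-kernel scope applies, else "op_chain".

    Hypothetical helper; it mirrors only the documented scope conditions.
    """
    cell = np.asarray(cell_matrix, dtype=float)
    # Orthorhombic: every off-diagonal cell entry is (numerically) zero.
    off_diag = cell - np.diag(np.diag(cell))
    orthorhombic = bool(np.all(np.abs(off_diag) <= atol))
    # Per-type parameter arrays are out of scope; only scalars qualify.
    scalar_params = np.ndim(epsilon) == 0 and np.ndim(sigma) == 0
    if orthorhombic and scalar_params and not has_coulomb and not has_exclusions:
        return "fused"
    return "op_chain"
```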
Because tests/conftest.py forces MLX_ATOMISTIC_DEVICE=cpu, the kernel is built
lazily on first use rather than at import, so importing this module never triggers a
Metal device load.
import mlx_atomistic.metal_kernels
Functions
fused_lj_forces
def fused_lj_forces(positions: mx.array, pairs: mx.array, box_lengths: mx.array, *, epsilon: float, sigma: float, cutoff: float, shift: bool) -> tuple[mx.array, mx.array]
Fused LJ energy + forces via a single Metal kernel (orthorhombic, scalar LJ).
Mirrors LennardJonesPotential._pair_energy_forces semantics: a half neighbor
list, pairs, of shape (M, 2) with each interacting pair listed once, an r^2 cutoff
mask, and an optional energy shift at the cutoff. box_lengths are the orthorhombic
edge lengths (mx.diag(cell.matrix)). Returns (energy_scalar, forces), with forces
of shape (N, 3).
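For reference, the pair quantities can be written out explicitly. This assumes the standard 12-6 Lennard-Jones form, a shift that subtracts U(r_c), L = box_lengths and r_c = cutoff; whether the mask comparison is strict (<) or not is also an assumption.

```latex
\begin{aligned}
\mathbf{d}_{ij} &= (\mathbf{x}_j - \mathbf{x}_i) - \mathbf{L} \odot \operatorname{round}\!\left(\frac{\mathbf{x}_j - \mathbf{x}_i}{\mathbf{L}}\right),
\qquad r^2 = \lVert \mathbf{d}_{ij} \rVert^2 \\
U(r) &= 4\epsilon\left[\left(\frac{\sigma}{r}\right)^{12} - \left(\frac{\sigma}{r}\right)^{6}\right],
\qquad
E = \sum_{(i,j)\,:\,r^2 < r_c^2} \bigl[\,U(r) - s\,U(r_c)\,\bigr],
\quad s = \begin{cases} 1 & \text{shift} \\ 0 & \text{otherwise} \end{cases} \\
\mathbf{F}_j^{(ij)} &= -\mathbf{F}_i^{(ij)} = \frac{24\epsilon}{r^2}\left[2\left(\frac{\sigma}{r}\right)^{12} - \left(\frac{\sigma}{r}\right)^{6}\right]\mathbf{d}_{ij}
\end{aligned}
```

The shift is a constant per pair inside the cutoff, so it changes the energy but not the forces.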
Parameters
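| Name | Type | Default | Description |
|---|---|---|---|
| positions | mx.array | required | Atom positions, shape (N, 3). |
| pairs | mx.array | required | Half neighbor list of atom index pairs, shape (M, 2). |
| box_lengths | mx.array | required | Orthorhombic edge lengths, i.e. mx.diag(cell.matrix). |
| epsilon | float | required | Scalar LJ well depth (reduced units). |
| sigma | float | required | Scalar LJ length parameter (reduced units). |
| cutoff | float | required | Cutoff distance used for the r^2 mask. |
| shift | bool | required | If True, shift the pair energy so that it is zero at the cutoff. |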
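To make the expected argument shapes concrete, here is a brute-force way to build pairs and box_lengths. It is a NumPy sketch with a hypothetical name (half_neighbor_list), O(N^2), and only meant to show the layout: each unordered pair appears once with i < j, and box_lengths is the diagonal of an orthorhombic cell, the NumPy analogue of mx.diag(cell.matrix). The library's own neighbor-list builder is not shown here.

```python
import numpy as np

def half_neighbor_list(positions, cell_matrix, cutoff):
    """Brute-force half neighbor list under minimum image (hypothetical helper).

    Returns (pairs, box_lengths): pairs is an (M, 2) int32 array with i < j for
    every pair closer than `cutoff`; box_lengths is the orthorhombic cell diagonal.
    """
    positions = np.asarray(positions, dtype=float)
    box_lengths = np.diag(np.asarray(cell_matrix, dtype=float))
    i, j = np.triu_indices(len(positions), k=1)  # each unordered pair exactly once
    d = positions[j] - positions[i]
    d = d - box_lengths * np.round(d / box_lengths)  # minimum image
    r2 = np.sum(d * d, axis=1)
    keep = r2 < cutoff * cutoff
    pairs = np.stack([i[keep], j[keep]], axis=1).astype(np.int32)
    return pairs, box_lengths
```

In practice a cell list or Verlet list replaces the O(N^2) scan; the output format is what matters for the call.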
Returns
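tuple[mx.array, mx.array]: (energy_scalar, forces), the total LJ energy as a scalar array and the per-atom forces of shape (N, 3).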
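Since the returned energy feeds a finite-difference path, a quick consistency check is to compare the returned forces with -dE/dx from central differences. The harness below is a hypothetical sketch: check_energy_forces is an illustrative name, and it accepts any callable with the same (energy, forces) return contract, for example a closure around fused_lj_forces with pairs, box_lengths and LJ parameters bound and results converted to NumPy.

```python
import numpy as np

def check_energy_forces(energy_forces, positions, h=1e-5):
    """Compare returned forces with -dE/dx from central differences (hypothetical).

    energy_forces: callable mapping an (N, 3) array to (energy, forces).
    Returns (max_abs_force_error, max_abs_net_force).
    """
    positions = np.asarray(positions, dtype=np.float64)
    _, forces = energy_forces(positions)
    forces = np.asarray(forces, dtype=np.float64)
    fd = np.zeros_like(positions)
    for a in range(positions.shape[0]):
        for k in range(3):
            plus = positions.copy()
            plus[a, k] += h
            minus = positions.copy()
            minus[a, k] -= h
            e_plus = float(energy_forces(plus)[0])
            e_minus = float(energy_forces(minus)[0])
            fd[a, k] = -(e_plus - e_minus) / (2.0 * h)
    force_err = float(np.max(np.abs(forces - fd)))
    # Equal-and-opposite pair forces from a half list should give zero net force.
    net_force = float(np.max(np.abs(forces.sum(axis=0))))
    return force_err, net_force
```

With float32 kernel outputs, expect tolerances far looser than in float64 and pick h accordingly.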