
metal_kernels

First custom Metal kernel for mlx_atomistic: a fused Lennard-Jones force kernel.

Collapses the per-step pairwise LJ force op-chain (gather -> minimum image -> r^2 -> LJ scalar -> scatter-add) into a single mx.fast.metal_kernel dispatch. Forces are accumulated with atomic scatter-adds over a half neighbor list, so each pair adds +F to atom i and -F to atom j. Per-pair energy is written to its own output slot (no write contention), and the caller sums the slots. This keeps the energy accurate, which the periodic-virial finite-difference path needs, without routing every pair through a single atomic energy accumulator that would become a contention hot spot.
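For reference, the unfused op-chain can be sketched in plain Python. This is a minimal sketch, not the library's LennardJonesPotential._pair_energy_forces: the name lj_reference is hypothetical, it assumes the standard 12-6 LJ form with a strict r^2 < cutoff^2 mask, and it loops over pairs instead of using MLX array ops. It keeps one energy slot per pair and scatter-adds each half-list pair's force to both atoms, which is the work the fused kernel performs in one dispatch.

```python
def lj_reference(positions, pairs, box_lengths, *, epsilon, sigma, cutoff, shift):
    """Hypothetical unfused reference: gather -> minimum image -> r^2 -> LJ scalar -> scatter-add."""
    n = len(positions)
    forces = [[0.0, 0.0, 0.0] for _ in range(n)]
    pair_energy = [0.0] * len(pairs)  # one slot per pair: no write contention
    rc2 = cutoff * cutoff
    s6 = sigma ** 6
    # Unshifted LJ energy at the cutoff; subtracted from every in-range pair when shift=True.
    e_cut = 4.0 * epsilon * (s6 * s6 / rc2 ** 6 - s6 / rc2 ** 3) if shift else 0.0

    for k, (i, j) in enumerate(pairs):
        # Gather + minimum image for an orthorhombic box.
        d = []
        for a in range(3):
            da = positions[i][a] - positions[j][a]
            length = box_lengths[a]
            da -= length * round(da / length)
            d.append(da)
        r2 = d[0] * d[0] + d[1] * d[1] + d[2] * d[2]
        if r2 >= rc2:
            continue  # r^2 cutoff mask (assumed strict)

        inv_r2 = 1.0 / r2
        sr6 = s6 * inv_r2 ** 3  # (sigma / r)^6
        pair_energy[k] = 4.0 * epsilon * (sr6 * sr6 - sr6) - e_cut

        # F_i = fscale * d_ij and F_j = -F_i (Newton's third law on a half list).
        fscale = 24.0 * epsilon * inv_r2 * (2.0 * sr6 * sr6 - sr6)
        for a in range(3):
            forces[i][a] += fscale * d[a]
            forces[j][a] -= fscale * d[a]

    # The caller-side reduction over per-pair energy slots.
    return sum(pair_energy), forces
```

Summing the per-pair slots at the end mirrors the design described above: the only atomics are the force scatter-adds, spread across N atoms, rather than one shared energy cell.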

Scope: pure-LJ reduced units, scalar epsilon/sigma, orthorhombic cell. Other cases (Coulomb, triclinic cells, topology exclusions, the biomolecular NonbondedPotential) stay on the MLX op-chain; callers fall back transparently.
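The scope rule amounts to an eligibility check in front of the kernel. The sketch below is hypothetical: can_use_fused_lj, lj_forces_with_fallback, and the flag names are illustrative only, since the module's actual dispatch code is not shown here. It treats a cell as orthorhombic when every off-diagonal entry of its 3x3 matrix is zero, and otherwise routes to the op-chain (the biomolecular NonbondedPotential would simply never take the fused path).

```python
def can_use_fused_lj(cell_matrix, *, has_coulomb=False, has_exclusions=False,
                     per_type_lj=False, tol=0.0):
    """Hypothetical check: pure scalar LJ in an orthorhombic cell only."""
    # Orthorhombic: all off-diagonal entries of the 3x3 cell matrix are (near) zero.
    orthorhombic = all(
        abs(cell_matrix[a][b]) <= tol
        for a in range(3)
        for b in range(3)
        if a != b
    )
    return orthorhombic and not (has_coulomb or has_exclusions or per_type_lj)


def lj_forces_with_fallback(cell_matrix, fused_path, op_chain_path, **flags):
    """Hypothetical dispatcher: fused kernel inside its scope, MLX op-chain otherwise."""
    if can_use_fused_lj(cell_matrix, **flags):
        return fused_path()
    return op_chain_path()
```

Because the fallback is decided per call, callers see the same (energy, forces) contract either way.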

Because tests/conftest.py forces MLX_ATOMISTIC_DEVICE=cpu, the kernel is built lazily on first use (not at import) so importing this module never triggers a Metal device load.

import mlx_atomistic.metal_kernels

def fused_lj_forces(positions: mx.array, pairs: mx.array, box_lengths: mx.array, *, epsilon: float, sigma: float, cutoff: float, shift: bool) -> tuple[mx.array, mx.array]

Fused LJ energy + forces via a single Metal kernel (orthorhombic, scalar LJ).

Mirrors the semantics of LennardJonesPotential._pair_energy_forces: pairs is a half neighbor list of shape (M, 2), pair interactions are masked by an r^2 cutoff, and the energy is optionally shifted so that it vanishes at the cutoff. box_lengths holds the orthorhombic edge lengths (mx.diag(cell.matrix)). Returns (energy, forces), where energy is a scalar and forces has shape (N, 3).
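Assuming the standard 12-6 Lennard-Jones form (the exact expressions used by the kernel are not shown here), each pair (i, j) in the half list contributes as follows, with minimum-image separation along each orthorhombic edge a, a cutoff mask (assumed strict), and s = 1 when shift is set and s = 0 otherwise:

```latex
\begin{aligned}
d_{ij,a} &= x_{i,a} - x_{j,a} - L_a \,\operatorname{round}\!\left(\frac{x_{i,a} - x_{j,a}}{L_a}\right),
\qquad r^2 = \lVert \mathbf{d}_{ij} \rVert^2 < r_c^2 \\[4pt]
U_{\mathrm{LJ}}(r) &= 4\varepsilon\left[\left(\frac{\sigma}{r}\right)^{12} - \left(\frac{\sigma}{r}\right)^{6}\right],
\qquad U_{ij} = U_{\mathrm{LJ}}(r) - s\,U_{\mathrm{LJ}}(r_c) \\[4pt]
\mathbf{F}_i &\mathrel{+}= \frac{24\varepsilon}{r^2}\left[2\left(\frac{\sigma}{r}\right)^{12} - \left(\frac{\sigma}{r}\right)^{6}\right]\mathbf{d}_{ij},
\qquad \mathbf{F}_j \mathrel{-}= \text{the same vector},
\qquad E = \sum_{(i,j)} U_{ij}
\end{aligned}
```

Because the shift is a constant per in-range pair, it changes the energy but not the forces; only r^2 and its powers are needed, so no square root is required.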

Parameters

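Name          Type       Default    Description
positions     mx.array   required   Atom positions, shape (N, 3).
pairs         mx.array   required   Half neighbor list of atom index pairs, shape (M, 2).
box_lengths   mx.array   required   Orthorhombic edge lengths, i.e. mx.diag(cell.matrix).
epsilon       float      required   LJ well depth (scalar, reduced units; keyword-only).
sigma         float      required   LJ length scale (scalar, reduced units; keyword-only).
cutoff        float      required   Pair cutoff distance used for the r^2 mask (keyword-only).
shift         bool       required   If true, shift the pair energy to zero at the cutoff (keyword-only).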

Returns

tuple[mx.array, mx.array]: (energy, forces), where energy is a scalar array and forces has shape (N, 3).