GALAX: Graph Learning with JAX



Galax is a graph-centric, high-performance library for graph modeling with JAX.

Installation

> pip install g3x

Design principles

  • Pure JAX: end-to-end differentiable and jittable.

  • Graphs (including heterographs with multiple node types), metagraphs, and node and edge data are simply pytrees (or more precisely, namedtuples), and are thus immutable.

  • All transforms (including neural networks inherited from flax) take and return graphs.

  • The API closely resembles DGL's, except that it is purely functional.
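
To illustrate the pytree and immutability points above, here is a minimal sketch in plain JAX. It does not use Galax itself: `ToyGraph` and `scale_features` are hypothetical names invented for this example, and Galax's real graph type carries more structure. The sketch only shows why a namedtuple-based graph can pass through `jax.jit` unchanged in type, and why "updating" it means building a new object.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyGraph(NamedTuple):
    """Hypothetical stand-in for a graph: edge endpoints plus node features."""
    src: jnp.ndarray  # source node index of each edge
    dst: jnp.ndarray  # destination node index of each edge
    h: jnp.ndarray    # node feature matrix, one row per node


@jax.jit
def scale_features(g: ToyGraph) -> ToyGraph:
    # A transform takes a graph and returns a graph. Namedtuples are
    # immutable, so _replace builds a new ToyGraph instead of mutating g.
    return g._replace(h=2.0 * g.h)


g = ToyGraph(src=jnp.array([0, 1]), dst=jnp.array([1, 2]), h=jnp.ones((3, 4)))
g2 = scale_features(g)

# JAX treats namedtuples as pytrees, so the arrays inside are its leaves.
leaves = jax.tree_util.tree_leaves(g)
```

Because the input `g` is never modified, the same graph can be reused safely across jitted or differentiated calls.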

Quick start

Implement a simple graph convolution in six lines.

>>> import jax; import jax.numpy as jnp; import galax
>>> g = galax.graph(([0, 1], [1, 2]))
>>> g = g.ndata.set("h", jnp.ones((3, 16)))
>>> g = g.update_all(galax.function.copy_u("h", "m"), galax.function.sum("m", "h"))
>>> W = jax.random.normal(key=jax.random.PRNGKey(2666), shape=(16, 16))
>>> g = g.apply_nodes(lambda node: {"h": node.data["h"] @ W})
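
For readers unfamiliar with the message-passing vocabulary, the following sketch spells out roughly what the quick start computes, using only plain JAX. It is an illustration under assumptions, not Galax's implementation: it assumes edges run from the first index list to the second (0→1 and 1→2, following the DGL convention), that `copy_u` copies each source node's feature onto its outgoing edge, and that `sum` adds the incoming messages at each destination node.

```python
import jax
import jax.numpy as jnp

# Same toy graph as the quick start: edges 0 -> 1 and 1 -> 2 (assumed direction).
src = jnp.array([0, 1])
dst = jnp.array([1, 2])
num_nodes = 3

h = jnp.ones((num_nodes, 16))

# "copy_u": each edge carries the feature of its source node as message "m".
messages = h[src]

# "sum": each node sums the messages arriving on its incoming edges.
h_agg = jax.ops.segment_sum(messages, dst, num_segments=num_nodes)

# Node-wise linear transform, as in the apply_nodes step.
W = jax.random.normal(jax.random.PRNGKey(2666), (16, 16))
h_out = h_agg @ W
```

In this sketch node 0 has no incoming edges, so its aggregated feature is all zeros, while nodes 1 and 2 each receive a single all-ones message.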
