GALAX: Graph Learning with JAX
Galax is a graph-centric, high-performance library for graph modeling with JAX.
Installation
> pip install g3x
Design principles
Pure JAX: end-to-end differentiable and jittable.
Graphs (including heterographs with multiple node types), metagraphs, and node and edge data are simply pytrees (or more precisely, namedtuples), and are thus immutable.
All transforms (including neural networks inherited from Flax) take graphs as input and return graphs.
The API closely resembles DGL's, except that it is purely functional.
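To illustrate what "graphs are namedtuple pytrees" buys you, here is a minimal sketch in plain JAX. It does not use Galax. `ToyGraph` and `scale_features` are hypothetical stand-ins, not Galax's actual graph class or API. Because a namedtuple is automatically a JAX pytree, a function that takes and returns one can be jitted and differentiated. Because it is immutable, a transform has to return a new graph rather than modify the old one.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyGraph(NamedTuple):
    """Hypothetical stand-in for a graph pytree; NOT Galax's real graph class."""
    src: jnp.ndarray  # source node index of each edge
    dst: jnp.ndarray  # destination node index of each edge
    h: jnp.ndarray    # node features, shape (num_nodes, num_features)


@jax.jit
def scale_features(g: ToyGraph, factor: float) -> ToyGraph:
    # A pure, graph-in/graph-out transform: _replace builds a new namedtuple
    # and leaves the input graph untouched.
    return g._replace(h=g.h * factor)


g = ToyGraph(src=jnp.array([0, 1]), dst=jnp.array([1, 2]), h=jnp.ones((3, 4)))
g2 = scale_features(g, 2.0)

# The original graph is unchanged; g2 is a separate pytree.
print(float(g.h.sum()), float(g2.h.sum()))  # 12.0 24.0

# Gradients flow through the graph-valued function end to end.
grad = jax.grad(lambda f: scale_features(g, f).h.sum())(2.0)
print(float(grad))  # 12.0
```

Attribute assignment such as `g.h = ...` raises an `AttributeError` on a namedtuple. This is the immutability that the design principles refer to.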
Quick start
Implement a simple graph convolution in six lines.
>>> import jax; import jax.numpy as jnp; import galax
>>> g = galax.graph(([0, 1], [1, 2]))
>>> g = g.ndata.set("h", jnp.ones((3, 16)))
>>> g = g.update_all(galax.function.copy_u("h", "m"), galax.function.sum("m", "h"))
>>> W = jax.random.normal(key=jax.random.PRNGKey(2666), shape=(16, 16))
>>> g = g.apply_nodes(lambda node: {"h": node.data["h"] @ W})
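For intuition about what the quick start computes, the sketch below spells out the same message passing in plain JAX, using `jax.ops.segment_sum` for the reduction. It is an illustrative equivalent under stated assumptions, not Galax's implementation. It assumes that `copy_u` copies each edge's source-node feature into a message and that `sum` adds the messages arriving at each destination node. It also assumes, following DGL's convention for built-in reducers, that nodes with no incoming edges end up with zeros. The variable names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Same toy graph as the quick start: edges 0 -> 1 and 1 -> 2.
src = jnp.array([0, 1])
dst = jnp.array([1, 2])
num_nodes = 3

h = jnp.ones((num_nodes, 16))

# copy_u("h", "m"): each edge carries its source node's feature as message "m".
m = h[src]  # shape (num_edges, 16)

# sum("m", "h"): each node sums the messages on its incoming edges.
# segment_sum yields zeros for nodes that receive no messages (node 0 here).
agg = jax.ops.segment_sum(m, dst, num_segments=num_nodes)

# apply_nodes: a per-node linear transform with weight W.
W = jax.random.normal(jax.random.PRNGKey(2666), (16, 16))
out = agg @ W

print(out.shape)  # (3, 16)
```

`segment_sum` takes a static `num_segments`, so output shapes are known at trace time and the whole computation stays jittable.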