galax.core.message_passing
- galax.core.message_passing(graph: Any, mfunc: Optional[Callable], rfunc: Optional[galax.function.ReduceFunction], afunc: Optional[Callable] = None, etype: Optional[Callable] = None)
Invoke message passing computation on the whole graph.
- Parameters
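graph (HeteroGraph) – The input graph.
mfunc (Callable) – Message function, which computes a message on each edge.
rfunc (galax.function.ReduceFunction) – Reduce function, which aggregates the incoming messages at each node.
afunc (Callable, optional) – Apply function. Defaults to None.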
- Returns
The resulting graph.
- Return type
HeteroGraph
Examples
>>> import galax
>>> import jax
>>> import jax.numpy as jnp
>>> g = galax.graph(((0, 1), (1, 2)))
>>> g = g.ndata.set("h", jnp.ones(3))
>>> mfunc = galax.function.copy_u("h", "m")
>>> rfunc = galax.function.sum("m", "h1")
>>> _g = galax.core.message_passing(g, mfunc, rfunc)
>>> _g.ndata['h1'].flatten().tolist()
[0.0, 1.0, 1.0]
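As a rough illustration of what the example above computes, here is a minimal plain-JAX sketch. It is not galax's implementation. It assumes the graph is given as hypothetical `src`/`dst` index arrays (edges 0→1 and 1→2, as in the example), hard-codes the `copy_u` message and `sum` reduction, and treats `afunc` as an optional node-wise update applied after the reduction. That last point is an assumption, since the exact contract of galax's apply function is not described here. The name `message_passing_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def message_passing_sketch(src, dst, h, num_nodes, afunc=None):
    """Hypothetical copy_u + sum message passing (not the galax API).

    src, dst: integer arrays giving each edge's source and destination node.
    h: node features whose leading dimension is num_nodes.
    afunc: ASSUMED to be an optional node-wise update applied after reduction.
    """
    # Message phase (like copy_u): each edge carries its source node's feature.
    m = h[src]
    # Reduce phase (like sum): add up incoming messages per destination node.
    # Nodes with no incoming edges receive 0.
    h1 = jax.ops.segment_sum(m, dst, num_segments=num_nodes)
    # Optional apply phase (assumed semantics): transform the reduced features.
    if afunc is not None:
        h1 = afunc(h1)
    return h1


# Same graph and features as the doctest: edges 0->1 and 1->2, h = ones(3).
src = jnp.array([0, 1])
dst = jnp.array([1, 2])
h = jnp.ones(3)
print(message_passing_sketch(src, dst, h, num_nodes=3).tolist())
# [0.0, 1.0, 1.0]
```

Using `jax.ops.segment_sum` with an explicit `num_segments` keeps the output shape static, which plays well with `jax.jit`. It also explains the leading `0.0` in the doctest output: node 0 has no incoming edges, so its sum is zero.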