galax.core.message_passing

galax.core.message_passing(graph: Any, mfunc: Optional[Callable], rfunc: Optional[galax.function.ReduceFunction], afunc: Optional[Callable] = None, etype: Optional[Callable] = None)

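Invoke message passing on the whole graph. A message is computed on every edge with mfunc. The messages arriving at each node are aggregated with rfunc. If afunc is given, it is then applied to the aggregated node features.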

Parameters
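  • graph (HeteroGraph) – The input graph.

  • mfunc (Callable, optional) – Message function, evaluated on each edge (for example, galax.function.copy_u).

  • rfunc (galax.function.ReduceFunction, optional) – Reduce function that aggregates the messages received by each node (for example, galax.function.sum).

  • afunc (Callable, optional) – Apply function, applied to the reduced node features. Defaults to None.

  • etype (optional) – Edge type. Defaults to None.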

Returns

A new graph whose node data contains the output of the reduce function (and of the apply function, if one is given).

Return type

HeteroGraph

Examples

>>> import galax
>>> import jax
>>> import jax.numpy as jnp
>>> g = galax.graph(((0, 1), (1, 2)))
>>> g = g.ndata.set("h", jnp.ones(3))
>>> mfunc = galax.function.copy_u("h", "m")
>>> rfunc = galax.function.sum("m", "h1")
>>> _g = galax.core.message_passing(g, mfunc, rfunc)
>>> _g.ndata['h1'].flatten().tolist()
[0.0, 1.0, 1.0]
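For intuition, the copy_u/sum pipeline in the example above can be reproduced in plain JAX without galax. The sketch below is an illustration, not galax's implementation. The helper naive_message_passing and its apply_fn argument are hypothetical. The graph is given as explicit source and destination index arrays, and jax.ops.segment_sum performs the per-node sum. Node 0 has no incoming edge, so it receives 0, which matches the doctest output.

```python
import jax.numpy as jnp
from jax import ops


def naive_message_passing(src, dst, h, num_nodes, apply_fn=None):
    """Hypothetical helper (not part of galax): copy_u messages, sum reduce, optional apply."""
    # Message phase: each edge carries its source node's feature (like copy_u).
    messages = h[src]
    # Reduce phase: sum the messages arriving at each destination node (like sum).
    # Nodes with no incoming edges receive 0.
    reduced = ops.segment_sum(messages, dst, num_segments=num_nodes)
    # Optional apply phase: transform the reduced node features (like afunc).
    if apply_fn is not None:
        reduced = apply_fn(reduced)
    return reduced


# Same graph as the doctest: edges 0 -> 1 and 1 -> 2, every node feature equal to 1.
src = jnp.array([0, 1])
dst = jnp.array([1, 2])
h = jnp.ones(3)
h1 = naive_message_passing(src, dst, h, num_nodes=3)
print(h1.tolist())  # [0.0, 1.0, 1.0]
```

Passing an apply function, for example apply_fn=lambda x: 2 * x, scales the reduced features after aggregation. In this sketch that gives [0.0, 2.0, 2.0].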