galax.function.apply_nodes

galax.function.apply_nodes(function: Callable, in_field: str = 'h', out_field: Optional[str] = None, ntype: Optional[str] = None)

Apply a function to a node data field and store the result in a node data field.

Parameters
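  • function (Callable) – Function applied to the node data stored in in_field.

  • in_field (str) – Name of the node field to read. Defaults to 'h'.

  • out_field (Optional[str]) – Name of the node field that receives the result. If None, the result overwrites in_field.

  • ntype (Optional[str]) – Node type to operate on.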

Returns

A function that takes a graph and returns a graph whose out_field node data holds the transformed values.

Return type

Callable
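The returned value is a graph-to-graph function rather than a transformed graph, which is what lets the example below pass it straight to jax.jit. The following is a minimal, dependency-free sketch of that pattern, not galax's implementation: ToyGraph and apply_nodes_sketch are hypothetical names, node data is modeled as a plain dict of lists, and ntype (node types) is not modeled at all.

```python
from typing import Callable, Dict, NamedTuple, Optional


class ToyGraph(NamedTuple):
    """Hypothetical stand-in for a galax graph: only node data is modeled."""
    ndata: Dict[str, list]


def apply_nodes_sketch(
    function: Callable,
    in_field: str = "h",
    out_field: Optional[str] = None,
) -> Callable[[ToyGraph], ToyGraph]:
    # Mirrors the documented default: out_field=None overwrites in_field.
    if out_field is None:
        out_field = in_field

    def _fn(graph: ToyGraph) -> ToyGraph:
        # Read in_field, transform it, and write the result to out_field
        # in a new graph; the input graph itself is left unchanged.
        new_ndata = dict(graph.ndata)
        new_ndata[out_field] = function(graph.ndata[in_field])
        return graph._replace(ndata=new_ndata)

    return _fn


graph = ToyGraph(ndata={"h": [1.0, 1.0, 1.0]})
double = apply_nodes_sketch(lambda xs: [2 * x for x in xs])
print(double(graph).ndata["h"])  # [2.0, 2.0, 2.0]
```

Passing out_field="h2" instead would keep "h" intact and add the doubled values under "h2".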

Examples

Double every value in the node field 'h':

>>> import jax
>>> import jax.numpy as jnp
>>> import galax
>>> from galax.function import apply_nodes
>>> graph = galax.graph(((0, 1), (1, 2)))
>>> graph = graph.ndata.set("h", jnp.ones(3))
>>> fn = apply_nodes(lambda x: x * 2)
>>> graph = jax.jit(fn)(graph)
>>> graph.ndata["h"].tolist()
[2.0, 2.0, 2.0]