galax.function.apply_edges

galax.function.apply_edges(function: Callable, in_field: str = 'h', out_field: Optional[str] = None, etype: Optional[str] = None)

Apply a function to edge attributes.

Parameters
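  • function (Callable) – Function applied to the edge data stored under in_field.

  • in_field (str) – Name of the edge data field to read. Defaults to 'h'.

  • out_field (str, optional) – Name of the edge data field to write. If None, the result is written back to in_field.

  • etype (str, optional) – Edge type to operate on. Defaults to None.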

Returns

A function that takes a graph and returns a new graph with the transformed edge data stored under out_field.

Return type

Callable

Examples

Multiply every edge's "h" value by 3 and run the transform under jax.jit:

>>> import jax
>>> import jax.numpy as jnp
>>> import galax
>>> from galax.function import apply_edges
>>> graph = galax.graph(((0, 1), (1, 2)))
>>> graph = graph.edata.set("h", jnp.ones(2))
>>> fn = apply_edges(lambda x: x * 3)
>>> graph = jax.jit(fn)(graph)
>>> graph.edata["h"].tolist()
[3.0, 3.0]
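apply_edges does not modify a graph directly; it builds a graph-to-graph function, which is why the example can pass the result to jax.jit. The following is a minimal pure-Python sketch of that pattern, not galax's implementation. Graph and apply_edges_sketch are hypothetical names, edge data is held in plain lists rather than JAX arrays, etype is omitted, and out_field falling back to in_field is an assumption consistent with the example above.

```python
from typing import Callable, Dict, List, NamedTuple, Optional


class Graph(NamedTuple):
    """Hypothetical stand-in for a galax graph: edge list plus edge data."""
    edges: List[tuple]
    edata: Dict[str, list]


def apply_edges_sketch(
    function: Callable,
    in_field: str = "h",
    out_field: Optional[str] = None,
) -> Callable[[Graph], Graph]:
    """Return a graph -> graph function that applies `function` to one edge field."""
    # Assumption: with no out_field, the result overwrites in_field.
    out_field = in_field if out_field is None else out_field

    def _fn(graph: Graph) -> Graph:
        # Copy the edata dict instead of mutating the input graph, so the
        # returned function is pure (the style jax.jit expects).
        new_edata = dict(graph.edata)
        # The function receives the whole field, not one edge at a time.
        new_edata[out_field] = function(graph.edata[in_field])
        return graph._replace(edata=new_edata)

    return _fn


# Usage: triple every edge value, mirroring the doctest.
g = Graph(edges=[(0, 1), (1, 2)], edata={"h": [1.0, 1.0]})
fn = apply_edges_sketch(lambda xs: [x * 3 for x in xs])
print(fn(g).edata["h"])  # [3.0, 3.0]
```

Because the sketch returns a closure instead of acting on a graph immediately, the same fn can be reused across graphs or composed with other graph-to-graph functions.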