galax.function.apply_edges
- galax.function.apply_edges(function: Callable, in_field: str = 'h', out_field: Optional[str] = None, etype: Optional[str] = None)
Apply a function to the edge attribute stored in in_field and write the result to out_field.
- Parameters
function (Callable) – Function to apply to the edge attribute.
in_field (str) – Input field.
out_field (Optional[str]) – Output field.
etype (Optional[str]) – Edge type.
- Returns
Function that takes and returns a graph.
- Return type
Callable
Examples
Transform function.
>>> import jax
>>> import jax.numpy as jnp
>>> import galax
>>> from galax.function import apply_edges
>>> graph = galax.graph(((0, 1), (1, 2)))
>>> graph = graph.edata.set("h", jnp.ones(2))
>>> fn = apply_edges(lambda x: x * 3)
>>> graph = jax.jit(fn)(graph)
>>> graph.edata["h"].tolist()
[3.0, 3.0]
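The following is a minimal pure-Python sketch of the higher-order pattern described above: apply_edges does not modify a graph directly but returns a function that takes a graph and returns a new one. ToyGraph and toy_apply_edges are hypothetical stand-ins, not part of galax, and the code is not galax's implementation. It also assumes that out_field falls back to in_field when left as None, which this page does not state. etype is omitted.

```python
from dataclasses import dataclass, field, replace
from typing import Callable, Dict, List, Optional


@dataclass(frozen=True)
class ToyGraph:
    # Hypothetical stand-in for a galax graph: an edge list plus edge data.
    src: tuple
    dst: tuple
    edata: Dict[str, List[float]] = field(default_factory=dict)


def toy_apply_edges(
    function: Callable, in_field: str = "h", out_field: Optional[str] = None
) -> Callable[[ToyGraph], ToyGraph]:
    # Assumption: when out_field is None, the result overwrites in_field.
    target = in_field if out_field is None else out_field

    def _fn(graph: ToyGraph) -> ToyGraph:
        # Build a new edge-data dict so the input graph is left untouched.
        new_edata = dict(graph.edata)
        new_edata[target] = function(graph.edata[in_field])
        return replace(graph, edata=new_edata)

    return _fn


graph = ToyGraph(src=(0, 1), dst=(1, 2), edata={"h": [1.0, 1.0]})
triple = toy_apply_edges(lambda xs: [x * 3 for x in xs])
new_graph = triple(graph)
print(new_graph.edata["h"])  # [3.0, 3.0]
print(graph.edata["h"])      # [1.0, 1.0], the input graph is unchanged
```

Returning a graph-to-graph function rather than mutating in place is what lets the real example above pass fn through jax.jit, since JAX transformations expect pure functions.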