Source code for galax.function

"""Built-in functions."""
import sys
from typing import Optional, Callable
from functools import partial
from itertools import product
from collections import namedtuple
import jax
import jax.numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn

# =============================================================================
# MESSAGE FUNCTIONS
# =============================================================================
CODE2STR = {
    "u": "source",
    "v": "destination",
    "e": "edge",
}

CODE2OP = {
    "add": lambda x, y: x + y,
    "sub": lambda x, y: x - y,
    "mul": lambda x, y: x * y,
    "div": lambda x, y: x / y,
    "dot": lambda x, y: (x * y).sum(axis=-1, keepdims=True),
}

CODE2DATA = {
    "u": "src",
    "e": "data",
    "v": "dst",
}


[docs]def copy_u(u, out): """Builtin message function that computes message using source node feature. Parameters ---------- u : str The source feature field. out : str The output message field. """ return lambda edge: {out: edge.src[u]}
# create alias copy_src = copy_u
[docs]def copy_e(e, out): """Builtin message function that computes message using edge feature. Parameters ---------- e : str The edge feature field. out : str The output message field. """ return lambda edge: {out: edge.data[e]}
# create alias copy_edge = copy_e def _gen_message_builtin(lhs, rhs, binary_op): name = "{}_{}_{}".format(lhs, binary_op, rhs) docstring = """Builtin message function that computes a message on an edge by performing element-wise {} between features of {} and {} if the features have the same shape; otherwise, it first broadcasts the features to a new shape and performs the element-wise operation. Broadcasting follows NumPy semantics. Please see https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html for more details about the NumPy broadcasting semantics. Parameters ---------- lhs_field : str The feature field of {}. rhs_field : str The feature field of {}. out : str The output message field. """.format( binary_op, CODE2STR[lhs], CODE2STR[rhs], CODE2STR[lhs], CODE2STR[rhs], ) # grab data field lhs_data, rhs_data = CODE2DATA[lhs], CODE2DATA[rhs] # define function def func(lhs_field, rhs_field, out): def fn(edges): return {out: CODE2OP[binary_op]( getattr(edges, lhs_data)[lhs_field], getattr(edges, rhs_data)[rhs_field], ) } return fn # attach name and doc func.__name__ = name func.__doc__ = docstring return func def _register_builtin_message_func(): """Register builtin message functions""" target = ["u", "v", "e"] for lhs, rhs in product(target, target): if lhs != rhs: for binary_op in ["add", "sub", "mul", "div", "dot"]: func = _gen_message_builtin(lhs, rhs, binary_op) setattr(sys.modules[__name__], func.__name__, func) _register_builtin_message_func() # ============================================================================= # REDUCE FUNCTIONS # ============================================================================= ReduceFunction = namedtuple( "ReduceFunction", ["op", "msg_field", "out_field"] ) sum = partial(ReduceFunction, "sum") mean = partial(ReduceFunction, "mean") max = partial(ReduceFunction, "max") min = partial(ReduceFunction, "min") segment_sum = jax.ops.segment_sum
[docs]def segment_max(*args, **kwargs): """Alias of jax.ops.segment_max with nan_to_num.""" return jnp.nan_to_num( jax.ops.segment_max(*args, **kwargs), nan=0.0, posinf=0.0, neginf=0.0, )
[docs]def segment_min(*args, **kwargs): """Alias of jax.ops.segment_min with nan_to_num.""" return jnp.nan_to_num( jax.ops.segment_min(*args, **kwargs), nan=0.0, posinf=0.0, neginf=0.0, )
[docs]def segment_mean( data: jnp.ndarray, segment_ids: jnp.ndarray, num_segments: Optional[int] = None, indices_are_sorted: bool = False, unique_indices: bool = False, ): """Returns mean for each segment. Reference --------- * Shamelessly stolen from jraph.utils Parameters ---------- data : jnp.ndarray the values which are averaged segment-wise. segment_ids : jnp.ndarray indices for the segments. num_segments : Optional[int] total number of segments. indices_are_sorted : bool=False whether ``segment_ids`` is known to be sorted. unique_indices : bool=False whether ``segment_ids`` is known to be free of duplicates. Returns ------- jnp.ndarray The data after segmentation sum. """ nominator = segment_sum( data, segment_ids, num_segments, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, ) denominator = segment_sum( jnp.ones_like(data), segment_ids, num_segments, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, ) return nominator / jnp.maximum( denominator, jnp.ones(shape=[], dtype=denominator.dtype) )
[docs]def segment_softmax(data: jnp.ndarray, segment_ids: jnp.ndarray, num_segments: Optional[int] = None, indices_are_sorted: bool = False, unique_indices: bool = False): """Computes a segment-wise softmax. For a given tree of logits that can be divded into segments, computes a softmax over the segments. logits = jnp.ndarray([1.0, 2.0, 3.0, 1.0, 2.0]) segment_ids = jnp.ndarray([0, 0, 0, 1, 1]) segment_softmax(logits, segments) >> DeviceArray([0.09003057, 0.24472848, 0.66524094, 0.26894142, 0.7310586], >> dtype=float32) Args: logits: an array of logits to be segment softmaxed. segment_ids: an array with integer dtype that indicates the segments of `data` (along its leading axis) to be maxed over. Values can be repeated and need not be sorted. Values outside of the range [0, num_segments) are dropped and do not contribute to the result. num_segments: optional, an int with positive value indicating the number of segments. The default is ``jnp.maximum(jnp.max(segment_ids) + 1, jnp.max(-segment_ids))`` but since ``num_segments`` determines the size of the output, a static value must be provided to use ``segment_sum`` in a ``jit``-compiled function. indices_are_sorted: whether ``segment_ids`` is known to be sorted unique_indices: whether ``segment_ids`` is known to be free of duplicates Returns: The segment softmax-ed ``logits``. """ # First, subtract the segment max for numerical stability maxs = segment_max(data, segment_ids, num_segments, indices_are_sorted, unique_indices) logits = data - maxs[segment_ids] # Then take the exp logits = jnp.exp(logits) # Then calculate the normalizers normalizers = segment_sum(logits, segment_ids, num_segments, indices_are_sorted, unique_indices) normalizers = normalizers[segment_ids] softmax = logits / normalizers return softmax
# ============================================================================= # APPLY FUNCTIONS # =============================================================================
[docs]def apply_nodes( function: Callable, in_field: str = "h", out_field: Optional[str] = None, ntype: Optional[str] = None, ): """Apply a function to node attributes. Parameters ---------- function : Callable Input function. in_field : str Input field out_field : str Output field. Returns ------- Callable Function that takes and returns a graph. Examples -------- Transform function. >>> import jax >>> import jax.numpy as jnp >>> import galax >>> graph = galax.graph(((0, 1), (1, 2))) >>> graph = graph.ndata.set("h", jnp.ones(3)) >>> fn = apply_nodes(lambda x: x * 2) >>> graph = jax.jit(fn)(graph) >>> graph.ndata['h'].tolist() [2.0, 2.0, 2.0] """ if out_field is None: out_field = in_field def _fn(graph, in_field=in_field, out_field=out_field, ntype=ntype): ntype_idx = graph.get_ntype_id(ntype) node_frame = unfreeze(graph.node_frames[ntype_idx]) node_frame[out_field] = function(node_frame[in_field]) node_frame = freeze(node_frame) node_frames = graph.node_frames[:ntype_idx] + (node_frame,)\ + graph.node_frames[ntype_idx+1:] return graph._replace(node_frames=node_frames) def __fn(graph, in_field=in_field, out_field=out_field): graph = graph.ndata.set(out_field, function(graph.ndata[in_field])) return graph if ntype is None: return __fn else: return _fn
[docs]def apply_edges( function: Callable, in_field: str = "h", out_field: Optional[str] = None, etype: Optional[str] = None, ): """Apply a function to edge attributes. Parameters ---------- function : Callable Input function. in_field : str Input field out_field : str Output field. Returns ------- Callable Function that takes and returns a graph. Examples -------- Transform function. >>> import jax >>> import jax.numpy as jnp >>> import galax >>> graph = galax.graph(((0, 1), (1, 2))) >>> graph = graph.edata.set("h", jnp.ones(2)) >>> fn = apply_edges(lambda x: x * 3) >>> graph = jax.jit(fn)(graph) >>> graph.edata['h'].tolist() [3.0, 3.0] """ if out_field is None: out_field = in_field def _fn(graph, in_field=in_field, out_field=out_field, etype=etype): etype_idx = graph.get_etype_id(etype) edge_frame = unfreeze(graph.edge_frames[etype_idx]) edge_frame[out_field] = function(edge_frame[in_field]) edge_frame = freeze(edge_frame) edge_frames = graph.edge_frames[:etype_idx] + (edge_frame, )\ + graph.edge_frames[etype_idx+1:] return graph._replace(edge_frames=edge_frames) return _fn