Source code for galax.nn.zoo.graphsage

"""`GraphSAGE <https://cs.stanford.edu/~jure/pubs/graphsage-nips17.pdf>`__"""

from typing import Callable, Optional
import jax
import jax.numpy as jnp
from flax import linen as nn
from ... import function as fn

[docs]class GraphSAGE(nn.Module): r"""GraphSAGE layer from `Inductive Representation Learning on Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__ .. math:: h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate} \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right) h_{i}^{(l+1)} &= \sigma \left(W \cdot \mathrm{concat} (h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) \right) h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{(l+1)}) Parameters ---------- in_feats : int, or pair of ints Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`. If aggregator type is ``gcn``, the feature size of source and destination nodes are required to be the same. out_feats : int Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`. aggregator_type : str Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``). feat_drop : float Dropout rate on features, default: ``0``. use_bias : bool If True, adds a learnable bias to the output. Default: ``True``. norm : callable activation function/layer or None, optional If not None, applies normalization to the updated node features. activation : callable activation function/layer or None, optional If not None, applies an activation function to the updated node features. Default: ``None``. Examples -------- >>> import jax >>> import jax.numpy as jnp >>> import galax >>> g = galax.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> g = g.add_self_loop() >>> g = g.set_ndata("h", jnp.ones((6, 10))) >>> graphsage = GraphSAGE(2, "pool") >>> params = graphsage.init(jax.random.PRNGKey(2666), g) >>> x = graphsage.apply(params, g).ndata["h"] >>> x.shape (6, 2) """ features: int aggregator_type: str = "mean" use_bias: bool = True activation: Optional[Callable] = None @nn.compact def __call__(self, graph, field="h"): h_self = graph.ndata[field] if self.aggregator_type == "mean": graph = graph.update_all( fn.copy_src(field, "m"), fn.mean("m", "neigh"), ) h_neigh = graph.ndata["neigh"] elif self.aggregator_type == "gcn": graph = graph.update_all( fn.copy_src(field, "m"), fn.sum("m", "neigh"), ) degrees = graph.in_degrees() h_neigh = (graph.ndata["neigh"] + graph.ndata[field])\ / (jnp.expand_dims(degrees, -1) + 1) elif self.aggregator_type == "pool": h_pool = jax.nn.relu( nn.Dense(graph.ndata[field].shape[-1])(graph.ndata[field]), ) graph = graph.ndata.set(field, h_pool) graph = graph.update_all( fn.copy_src(field, "m"), fn.max("m", "neigh"), ) h_neigh = graph.ndata["neigh"] h_neigh = nn.Dense(self.features, use_bias=False)(h_neigh) if self.aggregator_type == "gcn": rst = h_neigh else: rst = h_neigh + nn.Dense(self.features, use_bias=False)(h_self) if self.activation is not None: rst = self.activation(rst) graph = graph.ndata.set(field, rst) return graph