"""`Graph Convolutional network. <https://arxiv.org/abs/1609.02907>`__"""
from typing import Callable, Optional
import jax
import jax.numpy as jnp
from flax import linen as nn
from ... import function as fn
class GCN(nn.Module):
    r"""Graph convolutional layer from
    `Semi-Supervised Classification with Graph Convolutional
    Networks <https://arxiv.org/abs/1609.02907>`__

    Mathematically it is defined as follows:

    .. math::
      h_i^{(l+1)} =
      \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}
      \frac{1}{c_{ji}}h_j^{(l)}W^{(l)})

    where :math:`\mathcal{N}(i)` is the set of neighbors of node :math:`i`,
    :math:`c_{ji}` is the product of the square root of node degrees
    (i.e.,  :math:`c_{ji} = \sqrt{|\mathcal{N}(j)|}\sqrt{|\mathcal{N}(i)|}`),
    and :math:`\sigma` is an activation function.

    Both degree factors are computed from ``graph.out_degrees()``. Degrees
    are clamped to a minimum of 1 so that nodes without outgoing edges do
    not introduce ``inf``/``nan`` values.

    Parameters
    ----------
    features : int
        Output features size.
    use_bias : bool
        If ``True``, add a learnable bias of shape ``(features,)``,
        initialized to zeros. Defaults to ``False``.
    activation : Optional[Callable]
        Activation applied to the layer output. ``None`` (the default)
        means no activation (identity).

    Returns
    -------
    HeteroGraph
        The resulting Graph, with ``ndata[field]`` holding the layer output.

    Examples
    --------
    >>> import jax
    >>> import jax.numpy as jnp
    >>> import galax
    >>> g = galax.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = g.add_self_loop()
    >>> g = g.set_ndata("h", jnp.ones((6, 10)))
    >>> gcn = GCN(2, use_bias=True)
    >>> params = gcn.init(jax.random.PRNGKey(2666), g)
    >>> g = gcn.apply(params, g)
    >>> x = g.ndata['h']
    >>> x.shape
    (6, 2)
    """

    features: int
    use_bias: bool = False
    activation: Optional[Callable] = None

    @nn.compact
    def __call__(self, graph, field="h"):
        """Apply the graph convolution to ``graph.ndata[field]``.

        Parameters
        ----------
        graph
            Input graph; must carry node data under ``field``.
        field : str
            Name of the node-data entry that is read as input and
            overwritten with the output.

        Returns
        -------
        The graph with ``ndata[field]`` replaced by the layer output, whose
        last dimension is ``self.features``.
        """
        # Initialize parameters. The kernel's input size is inferred from
        # the last axis of the node features.
        kernel = self.param(
            "kernel",
            jax.nn.initializers.glorot_uniform(),
            (graph.ndata[field].shape[-1], self.features),
        )
        if self.use_bias:
            bias = self.param(
                "bias",
                jax.nn.initializers.zeros,
                (self.features,),
            )
        else:
            bias = 0.0

        activation = self.activation
        if activation is None:
            def activation(x):
                return x

        # deg^{-1/2}, reshaped so it broadcasts over the trailing feature
        # axes. Clamping to 1 avoids 0 ** -0.5 == inf on nodes that have no
        # outgoing edges.
        degrees = jnp.maximum(graph.out_degrees(), 1)
        norm = degrees ** (-0.5)
        norm_shape = norm.shape + (1,) * (graph.ndata[field].ndim - 1)
        norm = jnp.reshape(norm, norm_shape)

        # Source-side normalization has to happen BEFORE message passing:
        # each message h_j must carry its own 1/sqrt(deg_j). Scaling twice
        # after aggregation would give 1/deg_i instead of
        # 1/(sqrt(deg_j) * sqrt(deg_i)).
        def scale_source(h):
            return h * norm

        graph = fn.apply_nodes(scale_source, field)(graph)

        # Propagate: sum the (already scaled) neighbor features.
        graph = graph.update_all(fn.copy_u(field, "m"), fn.sum("m", field))

        # Destination-side normalization, linear transform, bias, activation.
        def transform(h):
            return activation((h * norm) @ kernel + bias)

        graph = fn.apply_nodes(transform, field)(graph)
        return graph