"""`Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__"""
from typing import Callable, Optional
import jax
import jax.numpy as jnp
from flax import linen as nn
from ... import function as fn
class GAT(nn.Module):
r"""
Apply `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__
over an input signal.

    .. math::

        h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{ij} W^{(l)} h_j^{(l)}

    where :math:`\alpha_{ij}` is the attention score between node :math:`i` and
    node :math:`j`:

    .. math::

        \alpha_{ij}^{l} &= \mathrm{softmax_i} (e_{ij}^{l})

        e_{ij}^{l} &=
            \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right)

    Parameters
    ----------
    features : int
        Output feature size for each attention head.
    num_heads : int
        Number of attention heads.
    feat_drop : float, optional
        Dropout rate on input features. Default: ``0``.
    attn_drop : float, optional
        Dropout rate on attention weights. Default: ``0``.
    negative_slope : float, optional
        LeakyReLU angle of negative slope. Default: ``0.2``.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.
    deterministic : bool, optional
        If True, disables feature and attention dropout (e.g. for evaluation).
        Default: ``True``.
    use_bias : bool, optional
        Whether to add a learnable bias to the aggregated output.
        Default: ``True``.

    Examples
    --------
>>> import jax
>>> import jax.numpy as jnp
>>> import galax
>>> g = galax.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = g.add_self_loop()
>>> g = g.set_ndata("h", jnp.ones((6, 10)))
>>> gat = GAT(2, 4, deterministic=True)
>>> params = gat.init(jax.random.PRNGKey(2666), g)
>>> g = gat.apply(params, g)
>>> x = g.ndata['h']
>>> x.shape
(6, 4, 2)
"""
features: int
num_heads: int
    feat_drop: float = 0.0
    attn_drop: float = 0.0
negative_slope: float = 0.2
activation: Optional[Callable] = None
deterministic: bool = True
use_bias: bool = True
    def setup(self):
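        # Shared linear projection W; its output of size
        # ``num_heads * features`` is later split into one slice per head.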
self.fc = nn.Dense(
self.features * self.num_heads, use_bias=False,
kernel_init=nn.initializers.variance_scaling(
3.0, "fan_avg", "uniform"
),
)
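        # Learnable attention vectors: ``attn_l`` and ``attn_r`` each map a
        # per-head representation to a scalar, and the two scalars are summed
        # across an edge to form the unnormalized attention logit.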
self.attn_l = nn.Dense(
1,
kernel_init=nn.initializers.variance_scaling(
3.0, "fan_avg", "uniform"
),
)
self.attn_r = nn.Dense(
1,
kernel_init=nn.initializers.variance_scaling(
3.0, "fan_avg", "uniform"
),
)
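        # Optional additive bias, one vector of size ``features`` per head,
        # applied after aggregation.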
if self.use_bias:
self.bias = self.param(
"bias",
nn.zeros,
(self.num_heads, self.features),
)
self.dropout_feat = nn.Dropout(self.feat_drop, deterministic=self.deterministic)
self.dropout_attn = nn.Dropout(self.attn_drop, deterministic=self.deterministic)
def __call__(self, graph, field="h", etype="E_"):
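        """Apply the layer to ``graph``: read node features from ``field``,
        attend over edges of type ``etype``, and write the updated features
        back to the same field."""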
h = graph.ndata[field]
h0 = h
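        # Dropout on the input features, then the shared linear projection.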
h = self.dropout_feat(h)
h = self.fc(h)
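        # Split the projected features into per-head representations:
        # (..., num_heads * features) -> (..., num_heads, features).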
h = h.reshape(h.shape[:-1] + (self.num_heads, self.features))
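        # Per-node halves of the attention logits; an edge's logit is the sum
        # of its source node's ``er`` and its destination node's ``el``.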
el = self.attn_l(h)
er = self.attn_r(h)
graph = graph.ndata.set(field, h)
graph = graph.ndata.set("er", er)
graph = graph.ndata.set("el", el)
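        # Raw attention logit on every edge, passed through LeakyReLU.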
e = graph.edges[etype].src["er"] + graph.edges[etype].dst["el"]
e = nn.leaky_relu(e, self.negative_slope)
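        # Normalize logits with a softmax over the incoming edges of each
        # destination node, then apply attention dropout.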
a = fn.segment_softmax(e, graph.edges[etype].dst.idxs, graph.number_of_nodes())
a = self.dropout_attn(a)
graph = graph.edata.set("a", a)
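        # Message passing: scale each source node's features by the edge
        # attention and sum the messages at the destination node.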
graph = graph.update_all(
fn.u_mul_e(field, "a", "m"),
fn.sum("m", field)
)
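        # Add the per-head bias to the aggregated features.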
if self.use_bias:
graph = fn.apply_nodes(
lambda x: x + self.bias, in_field=field, out_field=field
)(graph)
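        # Optionally apply the activation to the updated node features.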
if self.activation is not None:
graph = fn.apply_nodes(
self.activation, in_field=field, out_field=field
)(graph)
return graph