Source code for galax.nn.module

import abc
from typing import Callable, Optional
from flax import linen as nn
from ..function import apply_nodes, apply_edges

class ApplyNodes(nn.Module):
    """Module that passes one node-data field of a graph through ``layer``.

    Calling the module reads ``graph.ndata[field]``, feeds that value to
    ``layer``, and stores the output back under the same ``field`` name via
    ``graph.ndata.set``.
    """

    # Callable invoked on the value stored in the selected node-data field.
    layer: Callable

    def __call__(self, graph, field="h"):
        """Transform ``graph.ndata[field]`` with ``self.layer``.

        Parameters
        ----------
        graph
            Object exposing ``ndata`` with item access and a
            ``set(key, value)`` method.
        field : str
            Name of the node-data entry to read and overwrite.
            Defaults to ``"h"``.

        Returns
        -------
        The value returned by ``graph.ndata.set(field, ...)``; presumably a
        graph carrying the updated field (NOTE(review): confirm against the
        graph implementation).
        """
        transformed = self.layer(graph.ndata[field])
        return graph.ndata.set(field, transformed)