galax.nn.zoo.gat.GAT
- class galax.nn.zoo.gat.GAT(features: int, num_heads: int, feat_drop: Optional[float] = 0.0, attn_drop: Optional[float] = 0.0, negative_slope: float = 0.2, activation: Optional[Callable] = None, deterministic: bool = True, use_bias: bool = True, parent: Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
Bases: flax.linen.module.Module
Apply Graph Attention Network over an input signal.
\[h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{ij}^{(l)} W^{(l)} h_j^{(l)}\]
where \(\alpha_{ij}^{(l)}\) is the attention score between node \(i\) and node \(j\):
\[\begin{aligned} \alpha_{ij}^{(l)} &= \mathrm{softmax}_i \left(e_{ij}^{(l)}\right) \\ e_{ij}^{(l)} &= \mathrm{LeakyReLU}\left(\vec{a}^T [W^{(l)} h_i^{(l)} \| W^{(l)} h_j^{(l)}]\right) \end{aligned}\]
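The two formulas can be traced by hand on a small edge list. The following is a minimal single-head sketch in plain JAX, not the galax implementation: the helper name `gat_attention`, the random `W` and `a`, and the hand-built edge arrays (the example graph from this page with one self-loop per node) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def gat_attention(h, W, a, src, dst, negative_slope=0.2):
    """Single-head GAT update over an edge list, where edge k runs src[k] -> dst[k]."""
    n = h.shape[0]
    z = h @ W  # W h for every node, shape (n, f_out)
    # e_ij = LeakyReLU(a^T [W h_i || W h_j]) for each edge j -> i
    e = jax.nn.leaky_relu(jnp.concatenate([z[dst], z[src]], axis=-1) @ a, negative_slope)
    # Numerically stable softmax over the incoming edges of each node i
    e = e - jax.ops.segment_max(e, dst, num_segments=n)[dst]
    w = jnp.exp(e)
    alpha = w / jax.ops.segment_sum(w, dst, num_segments=n)[dst]
    # h_i' = sum_j alpha_ij W h_j
    out = jax.ops.segment_sum(alpha[:, None] * z[src], dst, num_segments=n)
    return out, alpha


# Hand-built edges: the six example edges followed by one self-loop per node.
src = jnp.array([0, 1, 2, 3, 2, 5, 0, 1, 2, 3, 4, 5])
dst = jnp.array([1, 2, 3, 4, 0, 3, 0, 1, 2, 3, 4, 5])
h = jnp.ones((6, 10))
W = jax.random.normal(jax.random.PRNGKey(0), (10, 2))  # illustrative weights
a = jax.random.normal(jax.random.PRNGKey(1), (4,))     # illustrative attention vector
out, alpha = gat_attention(h, W, a, src, dst)
print(out.shape)  # (6, 2)
```

Because every node carries a self-loop, each node has at least one incoming edge, so the per-node normalization never divides by zero. A multi-head layer would repeat this computation with one `W` and `a` per head.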
- Parameters
features (int) – Output feature size of each attention head.
num_heads (int) – Number of attention heads.
feat_drop (float, optional) – Dropout rate on feature. Defaults: 0.
attn_drop (float, optional) – Dropout rate on attention weight. Defaults: 0.
negative_slope (float, optional) – LeakyReLU angle of negative slope. Defaults: 0.2.
activation (callable activation function/layer or None, optional) – If not None, applies an activation function to the updated node features. Default: None.
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import galax
>>> from galax.nn.zoo.gat import GAT
>>> g = galax.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = g.add_self_loop()
>>> g = g.set_ndata("h", jnp.ones((6, 10)))
>>> gat = GAT(2, 4, deterministic=True)
>>> params = gat.init(jax.random.PRNGKey(2666), g)
>>> g = gat.apply(params, g)
>>> x = g.ndata['h']
>>> x.shape
(6, 4, 2)
- __init__(features: int, num_heads: int, feat_drop: Optional[float] = 0.0, attn_drop: Optional[float] = 0.0, negative_slope: float = 0.2, activation: Optional[Callable] = None, deterministic: bool = True, use_bias: bool = True, parent: Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]] = <flax.linen.module._Sentinel object>, name: Optional[str] = None) → None
Methods
__init__(features, num_heads[, feat_drop, ...])
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent]) – Creates a clone of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with name name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, mutable]) – Initializes a module method with variables and returns output and modified variables.
is_mutable_collection(col) – Returns true if the collection col is mutable.
make_rng(name) – Returns a new RNG key from a given RNG sequence for this Module.
param(name, init_fn, *init_args) – Declares and returns a parameter in this Module.
put_variable(col, name, value) – Sets the value of a Variable.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, method, mutable, ...]) – Creates a summary of the Module represented as a table.
variable(col, name[, init_fn]) – Declares and returns a variable in this Module.
Attributes
activation
attn_drop
deterministic
feat_drop
name
negative_slope
parent
scope
use_bias
variables – Returns the variables in this module.
features
num_heads
- apply(variables: Mapping[str, Mapping[str, Any]], *args, rngs: Optional[Dict[str, Any]] = None, method: Optional[Callable[[...], Any]] = None, mutable: Union[bool, str, Collection[str], flax.core.scope.DenyList] = False, capture_intermediates: Union[bool, Callable[[flax.linen.module.Module, str], bool]] = False, **kwargs) → Union[Any, Tuple[Any, flax.core.frozen_dict.FrozenDict[str, Mapping[str, Any]]]]
Applies a module method to variables and returns output and modified variables.
Note that method should be set if one would like to call apply on a different class method than __call__. For instance, suppose a Transformer module has a method called encode; then the following calls apply on that method:
model = Transformer()
encoded = model.apply({'params': params}, x, method=Transformer.encode)
If a function instance is provided, the unbound function is used. For instance, the example below is equivalent to the one above:
encoded = model.apply({'params': params}, x, method=model.encode)
Note
method can also be a function that is not defined in Transformer. In that case, the function should have at least one argument representing an instance of the Module class:
def other_fn(instance, ...):
    instance.some_module_attr(...)
    ...

model.apply({'params': params}, x, method=other_fn)
- Args:
variables: A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.
*args: Positional arguments passed to the specified apply method.
rngs: A dict of PRNGKeys to initialize the PRNG sequences. The “params” PRNG sequence is used to initialize parameters.
method: A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module.
mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.
capture_intermediates: If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.
**kwargs: Keyword arguments passed to the specified apply method.
- Returns:
If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.
- bind(variables: Mapping[str, Mapping[str, Any]], *args, rngs: Optional[Dict[str, Any]] = None, mutable: Union[bool, str, Collection[str], flax.core.scope.DenyList] = False)
Creates an interactive Module instance by binding variables and RNGs.
bind provides an “interactive” instance of a Module directly without transforming a function with apply. This is particularly useful for debugging and interactive use cases like notebooks, where a function would limit the ability to split up code into different cells.
Once the variables (and optionally RNGs) are bound to a Module, it becomes a stateful object. Note that idiomatic JAX is functional, and therefore an interactive instance does not mix well with vanilla JAX APIs. bind() should only be used for interactive experimentation; in all other cases we strongly encourage using apply() instead.
Example:
import jax
import jax.numpy as jnp
import flax.linen as nn

class AutoEncoder(nn.Module):
    def setup(self):
        self.encoder = nn.Dense(3)
        self.decoder = nn.Dense(5)

    def __call__(self, x):
        return self.decoder(self.encoder(x))

x = jnp.ones((16, 9))
ae = AutoEncoder()
variables = ae.init(jax.random.PRNGKey(0), x)
model = ae.bind(variables)
z = model.encoder(x)
x_reconstructed = model.decoder(z)
- Args:
variables: A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.
*args: Positional arguments (not used).
rngs: A dict of PRNGKeys to initialize the PRNG sequences.
mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.
- Returns:
A copy of this instance with bound variables and RNGs.
- clone(*, parent: Optional[Union[flax.core.scope.Scope, flax.linen.module.Module]] = None, **updates) → flax.linen.module.Module
Creates a clone of this Module, with optionally updated arguments.
- Args:
parent: The parent of the clone. The clone will have no parent if no explicit parent is specified.
**updates: Attribute updates.
- Returns:
A clone of this Module with the updated attributes and parent.
- get_variable(col: str, name: str, default: Optional[flax.linen.module.T] = None) → flax.linen.module.T
Retrieves the value of a Variable.
- Args:
col: the variable collection.
name: the name of the variable.
default: the default value to return if the variable does not exist in this scope.
- Returns:
The value of the input variable, or the default value if the variable doesn’t exist in this scope.
- has_rng(name: str) → bool
Returns true if a PRNGSequence with name name exists.
- has_variable(col: str, name: str) → bool
Checks if a variable of given collection and name exists in this Module.
See flax.core.variables for more explanation on variables and collections.
- Args:
col: The variable collection name.
name: The name of the variable.
- Returns:
True if the variable exists.
- init(rngs: Union[Any, Dict[str, Any]], *args, method: Optional[Callable[[...], Any]] = None, mutable: Union[bool, str, Collection[str], flax.core.scope.DenyList] = DenyList(deny='intermediates'), **kwargs) → flax.core.frozen_dict.FrozenDict[str, Mapping[str, Any]]
Initializes a module method with variables and returns modified variables.
Jitting init initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values. Example:
jit_init = jax.jit(SomeModule(...).init)
jit_init(rng, jnp.ones(input_shape, jnp.float32))
- Args:
rngs: The rngs for the variable collections.
*args: Positional arguments passed to the init function.
method: An optional method. If provided, applies this method. If not provided, applies the __call__ method.
mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.
**kwargs: Keyword arguments passed to the init function.
- Returns:
The initialized variable dict.
- init_with_output(rngs: Union[Any, Dict[str, Any]], *args, method: Optional[Callable[[...], Any]] = None, mutable: Union[bool, str, Collection[str], flax.core.scope.DenyList] = DenyList(deny='intermediates'), **kwargs) → Tuple[Any, flax.core.frozen_dict.FrozenDict[str, Mapping[str, Any]]]
Initializes a module method with variables and returns output and modified variables.
- Args:
rngs: The rngs for the variable collections.
*args: Positional arguments passed to the init function.
method: An optional method. If provided, applies this method. If not provided, applies the __call__ method.
mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.
**kwargs: Keyword arguments passed to the init function.
- Returns:
(output, vars), where vars is a dict of the modified collections.
- is_mutable_collection(col: str) → bool
Returns true if the collection col is mutable.
- make_rng(name: str) → Any
Returns a new RNG key from a given RNG sequence for this Module.
The new RNG key is split from the previous one. Thus, every call to make_rng returns a new RNG key, while still guaranteeing full reproducibility.
- Args:
name: The RNG sequence name.
- Returns:
The newly generated RNG key.
- param(name: str, init_fn: Callable[[...], flax.linen.module.T], *init_args) → flax.linen.module.T
Declares and returns a parameter in this Module.
Parameters are read-only variables in the collection named “params”. See flax.core.variables for more details on variables.
The first argument of init_fn is assumed to be a PRNG key, which is provided automatically and does not have to be passed using init_args:
mean = self.param('mean', lecun_normal(), (2, 2))
In the example above, the function lecun_normal expects two arguments: key and shape, but only shape has to be provided explicitly; key is set automatically using the PRNG for params that is passed when initializing the module using init().
- Args:
name: The parameter name.
init_fn: The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.
*init_args: The arguments to pass to init_fn.
- Returns:
The value of the initialized parameter.
- put_variable(col: str, name: str, value: Any)
Sets the value of a Variable.
- Args:
col: the variable collection.
name: the name of the variable.
value: the new value of the variable.
- setup()
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
Immediately when invoking apply(), init() or init_with_output().
Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
class MyModule(nn.Module):
    def setup(self):
        submodule = Conv(...)

        # Accessing `submodule` attributes does not yet work here.

        # The following line invokes `self.__setattr__`, which gives
        # `submodule` the name "conv1".
        self.conv1 = submodule

        # Accessing `submodule` attributes or methods is now safe and
        # either causes setup() to be called once.
Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.
- sow(col: str, name: str, value: flax.linen.module.T, reduce_fn: Callable[[flax.linen.module.K, flax.linen.module.T], flax.linen.module.K] = <function <lambda>>, init_fn: Callable[[], flax.linen.module.K] = <function <lambda>>) → bool
Stores a value in a collection.
Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.
If the target collection is not mutable sow behaves like a no-op and returns False.
Example:
import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4)(x)
        self.sow('intermediates', 'h', h)
        return nn.Dense(2)(h)

x = jnp.ones((16, 9))
model = Foo()
variables = model.init(jax.random.PRNGKey(0), x)
y, state = model.apply(variables, x, mutable=['intermediates'])
print(state['intermediates'])  # {'h': (...,)}
By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init/reduce function can be passed:
class Foo2(nn.Module):
    @nn.compact
    def __call__(self, x):
        init_fn = lambda: 0
        reduce_fn = lambda a, b: a + b
        self.sow('intermediates', 'h', x,
                 init_fn=init_fn, reduce_fn=reduce_fn)
        self.sow('intermediates', 'h', x * 2,
                 init_fn=init_fn, reduce_fn=reduce_fn)
        return x

model = Foo2()
variables = model.init(jax.random.PRNGKey(0), x)
y, state = model.apply(variables, jnp.ones((1, 1)), mutable=['intermediates'])
print(state['intermediates'])  # ==> {'h': [[3.]]}
- Args:
col: The name of the variable collection.
name: The name of the variable.
value: The value of the variable.
reduce_fn: The function used to combine the existing value with the new value. The default is to append the value to a tuple.
init_fn: For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.
- Returns:
True if the value has been stored successfully, False otherwise.
- tabulate(rngs: Union[Any, Dict[str, Any]], *args, method: Optional[Callable[[...], Any]] = None, mutable: Union[bool, str, Collection[str], flax.core.scope.DenyList] = True, depth: Optional[int] = None, exclude_methods: Sequence[str] = (), **kwargs) → str
Creates a summary of the Module represented as a table.
This method has the same signature as init, but instead of returning the variables, it returns the string summarizing the Module in a table. tabulate uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.
Example:
import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4)(x)
        return nn.Dense(2)(h)

x = jnp.ones((16, 9))
print(Foo().tabulate(jax.random.PRNGKey(0), x))
This gives the following output:
                   Foo Summary
┏━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ outputs       ┃ params               ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ Inputs  │ float32[16,9] │                      │
├─────────┼───────────────┼──────────────────────┤
│ Dense_0 │ float32[16,4] │ bias: float32[4]     │
│         │               │ kernel: float32[9,4] │
│         │               │                      │
│         │               │ 40 (160 B)           │
├─────────┼───────────────┼──────────────────────┤
│ Dense_1 │ float32[16,2] │ bias: float32[2]     │
│         │               │ kernel: float32[4,2] │
│         │               │                      │
│         │               │ 10 (40 B)            │
├─────────┼───────────────┼──────────────────────┤
│ Foo     │ float32[16,2] │                      │
├─────────┼───────────────┼──────────────────────┤
│         │         Total │ 50 (200 B)           │
└─────────┴───────────────┴──────────────────────┘

          Total Parameters: 50 (200 B)
Note: row order in the table does not represent execution order; instead, it follows the order of keys in variables, which are sorted alphabetically.
- Args:
rngs: The rngs for the variable collections.
*args: The arguments to the forward computation.
method: An optional method. If provided, applies this method. If not provided, applies the __call__ method.
mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except ‘intermediates’ are mutable.
depth: Controls how many submodules deep the summary can go. By default it is None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes will be added to the row of its first shown ancestor such that the sum of all rows always adds up to the total number of parameters of the Module.
exclude_methods: A sequence of strings that specifies which methods should be ignored. In case a module calls a helper method from its main method, use this argument to exclude the helper method from the summary to avoid ambiguity.
**kwargs: Keyword arguments to pass to the forward computation.
- Returns:
A string summarizing the Module.
- variable(col: str, name: str, init_fn: Optional[Callable[[...], Any]] = None, *init_args) → flax.core.scope.Variable
Declares and returns a variable in this Module.
See flax.core.variables for more information. See also param() for a shorthand way to define read-only variables in the “params” collection.
Contrary to param(), all arguments passed to init_fn should be passed on explicitly:
key = self.make_rng('stats')
mean = self.variable('stats', 'mean', lecun_normal(), key, (2, 2))
In the example above, the function lecun_normal expects two arguments: key and shape, and both have to be passed on. The PRNG for stats has to be provided explicitly when calling init() and apply().
- Args:
col: The variable collection name.
name: The variable name.
init_fn: The function that will be called to compute the initial value of this variable. This function will only be called the first time this variable is used in this module. If None, the variable must already be initialized, otherwise an error is raised.
*init_args: The arguments to pass to init_fn.
- Returns:
A flax.core.variables.Variable that can be read or set via “.value” attribute. Throws an error if the variable exists already.
- property variables: Mapping[str, Mapping[str, Any]]
Returns the variables in this module.