galax.data.dataloader.PrixFixeDataLoader

class galax.data.dataloader.PrixFixeDataLoader(graphs: Sequence[galax.heterograph.HeteroGraph], batch_size: int = 1)

Bases: object

A helper object that shuffles graphs and iterates over them in batches.

Parameters
  • graphs (Sequence[HeteroGraph]) – Graphs to iterate over.

  • batch_size (int = 1) – Number of graphs combined into each batch.
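The shuffle-then-batch pattern described above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the galax implementation: `shuffled_batches` and its `seed` parameter are hypothetical, and whether galax yields, drops, or pads a final batch shorter than `batch_size` is not stated on this page.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def shuffled_batches(
    items: Sequence[T], batch_size: int = 1, seed: int = 0
) -> Iterator[List[T]]:
    """Hypothetical helper: shuffle `items`, then yield chunks of `batch_size`."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    order = list(range(len(items)))
    # Seeded RNG so the shuffled order is reproducible across runs.
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        # Assumption: the last chunk is yielded even if it is short.
        yield [items[i] for i in order[start:start + batch_size]]


for batch in shuffled_batches(["g0", "g1", "g2"], batch_size=2):
    print(batch)  # one batch of two items, then one batch of one item
```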

Examples

>>> import galax
>>> from galax.data.dataloader import PrixFixeDataLoader
>>> g0 = galax.graph(((0, 1), (1, 2)))
>>> g1 = galax.graph(((0, 1, 2), (1, 2, 3)))
>>> g2 = galax.graph(((0, 1, 2, 3), (1, 2, 3, 4)))
>>> dataloader = PrixFixeDataLoader((g0, g1, g2), batch_size=3)
>>> dataloader.max_num_edges.item()
9
>>> dataloader.max_num_nodes.item()
12
>>> g = next(iter(dataloader))
>>> int(g.number_of_nodes())
12
>>> int(g.number_of_edges())
9
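In the example, max_num_nodes (12) and max_num_edges (9) equal the totals over all three graphs: g0, g1 and g2 have 3, 4 and 5 nodes and 2, 3 and 4 edges. One rule consistent with these numbers, assumed here rather than taken from the galax source, is to sum the batch_size largest per-graph counts. That gives a fixed node and edge budget that every batch fits within, which is useful when batches must share one shape (for example, to avoid recompiling a jitted JAX function). `padding_budget` below is a hypothetical helper illustrating that rule.

```python
from typing import Sequence, Tuple


def padding_budget(
    num_nodes: Sequence[int], num_edges: Sequence[int], batch_size: int
) -> Tuple[int, int]:
    """Hypothetical rule: upper bounds on the node and edge totals of any batch."""
    # Nodes and edges are bounded independently, so the pair may not be
    # reached by a single batch; it is still a valid upper bound for both.
    top_nodes = sorted(num_nodes, reverse=True)[:batch_size]
    top_edges = sorted(num_edges, reverse=True)[:batch_size]
    return sum(top_nodes), sum(top_edges)


# Node and edge counts of g0, g1 and g2 from the example above.
print(padding_budget([3, 4, 5], [2, 3, 4], batch_size=3))  # (12, 9)
```

With batch_size=3 and exactly three graphs, this rule cannot be told apart from simply summing all graphs; the two differ only when there are more graphs than batch_size.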

__init__(graphs: Sequence[galax.heterograph.HeteroGraph], batch_size: int = 1)

Methods

__init__(graphs[, batch_size])