galax.function.segment_mean

galax.function.segment_mean(data: jnp.ndarray, segment_ids: jnp.ndarray, num_segments: Optional[int] = None, indices_are_sorted: bool = False, unique_indices: bool = False)

Returns the mean of data for each segment.

  • Adapted from jraph.utils.
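The computation can be pictured as a segment-wise sum divided by a segment-wise count. The sketch below is a minimal illustration under that assumption, built on jax.ops.segment_sum; it is not necessarily galax's exact code, and the helper name segment_mean_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def segment_mean_sketch(data, segment_ids, num_segments=None,
                        indices_are_sorted=False, unique_indices=False):
    """Hypothetical helper: segment-wise mean as (segment sum) / (segment count)."""
    # Sum the values of data that fall into each segment (along the leading axis).
    total = jax.ops.segment_sum(
        data, segment_ids, num_segments,
        indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
    # Count the entries in each segment by summing a tensor of ones of the same shape.
    count = jax.ops.segment_sum(
        jnp.ones_like(data), segment_ids, num_segments,
        indices_are_sorted=indices_are_sorted, unique_indices=unique_indices)
    # Clamp the count at 1 so that empty segments give 0 rather than NaN (0 / 0).
    return total / jnp.maximum(count, 1)
```

For example, data [1., 2., 3., 4.] with segment_ids [0, 0, 1, 1] gives [1.5, 3.5]; with segment_ids [0, 0, 2, 2] and num_segments=3, the empty segment 1 comes out as 0. Clamping the denominator is one choice for handling empty segments; whether galax does the same is an assumption.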

Parameters
  • data (jnp.ndarray) – the values which are averaged segment-wise.

  • segment_ids (jnp.ndarray) – integer segment index of each entry along the leading axis of data.

  • num_segments (Optional[int]) – total number of segments.

  • indices_are_sorted (bool=False) – whether segment_ids is known to be sorted.

  • unique_indices (bool=False) – whether segment_ids is known to be free of duplicates.
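In jax.ops.segment_sum, num_segments determines the output shape, so a static value is needed inside a jit-compiled function. Assuming segment_mean forwards num_segments in the same way, a jit-friendly call can mark it as a static argument. The sketch below shows this with multi-feature data; jitted_segment_mean and the node/graph arrays are hypothetical illustrations, not part of galax.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="num_segments")
def jitted_segment_mean(data, segment_ids, num_segments):
    # num_segments is static, so the output shape is known at trace time.
    total = jax.ops.segment_sum(data, segment_ids, num_segments)
    count = jax.ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments)
    # Clamp at 1 to avoid 0 / 0 for empty segments.
    return total / jnp.maximum(count, 1)


# Hypothetical data: 5 nodes with 2 features each, spread over 2 graphs.
node_features = jnp.array([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0],
                           [2.0, 20.0], [6.0, 60.0]])
graph_ids = jnp.array([0, 0, 0, 1, 1])
means = jitted_segment_mean(node_features, graph_ids, num_segments=2)
```

The result has shape (num_segments,) + data.shape[1:], here (2, 2): one row of averaged features per graph.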

Returns

The mean of data within each segment.

Return type

jnp.ndarray