galax.function.segment_mean
- galax.function.segment_mean(data: jax._src.numpy.ndarray.ndarray, segment_ids: jax._src.numpy.ndarray.ndarray, num_segments: Optional[int] = None, indices_are_sorted: bool = False, unique_indices: bool = False)
Returns mean for each segment.
Shamelessly stolen from jraph.utils
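Since the function is adapted from jraph.utils, the sketch below assumes it follows the same approach: sum the values in each segment, count the entries in each segment, and divide. It is a minimal illustration built on `jax.ops.segment_sum`, not the galax source. The name `segment_mean_sketch` is hypothetical. Under this assumption, segments with no entries come out as 0 rather than NaN because the count is clamped to at least 1.

```python
import jax.numpy as jnp
from jax import ops


def segment_mean_sketch(data, segment_ids, num_segments=None,
                        indices_are_sorted=False, unique_indices=False):
    """Hypothetical sketch of a jraph-style segment mean (not the galax code)."""
    # Sum of the values that fall into each segment.
    totals = ops.segment_sum(data, segment_ids, num_segments,
                             indices_are_sorted=indices_are_sorted,
                             unique_indices=unique_indices)
    # Number of entries per segment: a segment sum over ones shaped like data.
    counts = ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments,
                             indices_are_sorted=indices_are_sorted,
                             unique_indices=unique_indices)
    # Clamp counts to >= 1 so empty segments yield 0 instead of 0/0 = NaN.
    return totals / jnp.maximum(counts, 1)


data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 0, 1, 1])
# Segment 2 is empty, so its mean is 0 under the clamping assumption.
print(segment_mean_sketch(data, segment_ids, num_segments=3))  # → [1.5, 3.5, 0.0]
```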
- Parameters
data (jnp.ndarray) – the values which are averaged segment-wise.
segment_ids (jnp.ndarray) – the segment index of each entry along the leading axis of data.
num_segments (Optional[int]) – total number of segments.
indices_are_sorted (bool=False) – whether segment_ids is known to be sorted.
unique_indices (bool=False) – whether segment_ids is known to be free of duplicates.
- Returns
The mean of data over each segment.
- Return type
jnp.ndarray
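A common use in graph code is averaging per-edge feature vectors into their receiving nodes. The standalone sketch below assumes a jraph-style sum-divided-by-count computation and uses only `jax.ops.segment_sum`; the function name `mean_messages` and the edge/node framing are hypothetical. It shows that 2-D data yields an output of shape `(num_segments, features)`, and that `num_segments` is marked static under `jax.jit` because it determines the output shape.

```python
import functools

import jax
import jax.numpy as jnp
from jax import ops


# num_segments fixes the output shape, so jax.jit must treat it as static.
@functools.partial(jax.jit, static_argnames="num_segments")
def mean_messages(messages, receivers, num_segments):
    # messages: (num_edges, features); receivers: (num_edges,) node index per edge.
    totals = ops.segment_sum(messages, receivers, num_segments)
    counts = ops.segment_sum(jnp.ones_like(messages), receivers, num_segments)
    # Nodes with no incoming edges get a zero vector rather than NaN.
    return totals / jnp.maximum(counts, 1)


messages = jnp.array([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]])
receivers = jnp.array([0, 0, 2])
out = mean_messages(messages, receivers, num_segments=4)
print(out.shape)  # (4, 2)
```

Rows 1 and 3 receive no edges, so under this assumption they are all zeros; row 0 averages the first two messages to `[2.0, 20.0]`.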