galax.function.segment_softmax
- galax.function.segment_softmax(data: jnp.ndarray, segment_ids: jnp.ndarray, num_segments: Optional[int] = None, indices_are_sorted: bool = False, unique_indices: bool = False)
Computes a segment-wise softmax. Given an array of logits that is divided into segments, computes a softmax independently within each segment, so the outputs belonging to each segment sum to 1. For example:

    >>> logits = jnp.array([1.0, 2.0, 3.0, 1.0, 2.0])
    >>> segment_ids = jnp.array([0, 0, 0, 1, 1])
    >>> segment_softmax(logits, segment_ids)
    DeviceArray([0.09003057, 0.24472848, 0.66524094, 0.26894142, 0.7310586],
                dtype=float32)

Args:
- data: an array of logits to be segment-softmaxed.
- segment_ids: an array with integer dtype that indicates the segments of data (along its leading axis) to be softmaxed over. Values can be repeated and need not be sorted. Values outside of the range [0, num_segments) are dropped and do not contribute to the result.
- num_segments: optional, a positive int indicating the number of segments. The default is jnp.maximum(jnp.max(segment_ids) + 1, jnp.max(-segment_ids)), but since num_segments determines the shape of the per-segment reductions computed internally, a static value must be provided to use segment_softmax in a jit-compiled function.
- indices_are_sorted: whether segment_ids is known to be sorted.
- unique_indices: whether segment_ids is known to be free of duplicates.

Returns:
The segment-softmaxed values, an array with the same shape as data.
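For reference, writing x for data and s for segment_ids, the value at position i is the ordinary softmax restricted to the entries that share segment s_i. The per-segment maximum m_k below is the standard numerical-stability shift (an assumption about how the value is computed, not something this page documents); it cancels between numerator and denominator, so it does not change the result.

```latex
\operatorname{segment\_softmax}(x, s)_i
  = \frac{\exp\left(x_i - m_{s_i}\right)}
         {\sum_{j \,:\, s_j = s_i} \exp\left(x_j - m_{s_i}\right)},
\qquad
m_k = \max_{j \,:\, s_j = k} x_j .
```

Checking against the example: for segment 0, exp(1 - 3) / (exp(-2) + exp(-1) + exp(0)) is about 0.0900, matching the first output.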
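The computation can be sketched in plain JAX with jax.ops.segment_max and jax.ops.segment_sum. This is a hypothetical re-implementation, not galax's source: segment_softmax_sketch and segment_softmax_jit are illustrative names, and the subtract-the-segment-max formulation is an assumption. It also illustrates the num_segments requirement: under jax.jit, num_segments (and the two boolean hints) are marked static.

```python
import jax
import jax.numpy as jnp


def segment_softmax_sketch(data, segment_ids, num_segments=None,
                           indices_are_sorted=False, unique_indices=False):
    """Hypothetical segment-wise softmax; an illustration, not galax's source."""
    # Per-segment maximum, one entry per segment. Subtracting it before exp()
    # keeps the exponentials from overflowing and does not change the result.
    maxs = jax.ops.segment_max(data, segment_ids, num_segments,
                               indices_are_sorted, unique_indices)
    exps = jnp.exp(data - maxs[segment_ids])
    # Per-segment normalizer: the sum of the exponentials in each segment.
    sums = jax.ops.segment_sum(exps, segment_ids, num_segments,
                               indices_are_sorted, unique_indices)
    # Gather each element's normalizer back to its position and divide.
    return exps / sums[segment_ids]


# Inside jax.jit, num_segments must be a static Python int because it fixes the
# shape of the per-segment reductions; the two boolean hints are static as well.
segment_softmax_jit = jax.jit(
    segment_softmax_sketch,
    static_argnames=("num_segments", "indices_are_sorted", "unique_indices"),
)

logits = jnp.array([1.0, 2.0, 3.0, 1.0, 2.0])
segment_ids = jnp.array([0, 0, 0, 1, 1])

# Eager call: num_segments is inferred from the concrete segment_ids.
print(segment_softmax_sketch(logits, segment_ids))
# approximately [0.090, 0.245, 0.665, 0.269, 0.731]

# Jitted call: num_segments is given statically; segment_ids is already sorted,
# so passing indices_are_sorted=True is a safe hint.
print(segment_softmax_jit(logits, segment_ids, num_segments=2,
                          indices_are_sorted=True))
```

Leaving num_segments as None inside jax.jit would require reading the maximum of a traced segment_ids array, whose value is not known at trace time, so JAX raises an error instead of compiling.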