galax.function.segment_softmax
- galax.function.segment_softmax(data: jax._src.numpy.ndarray.ndarray, segment_ids: jax._src.numpy.ndarray.ndarray, num_segments: Optional[int] = None, indices_are_sorted: bool = False, unique_indices: bool = False)
Computes a segment-wise softmax.

For a given array of logits that can be divided into segments, computes a softmax over each segment.

    logits = jnp.array([1.0, 2.0, 3.0, 1.0, 2.0])
    segment_ids = jnp.array([0, 0, 0, 1, 1])
    segment_softmax(logits, segment_ids)
    >> DeviceArray([0.09003057, 0.24472848, 0.66524094, 0.26894142, 0.7310586],
    >>             dtype=float32)

Args:
data: an array of logits to be segment softmaxed.
segment_ids: an array with integer dtype that indicates the segments of data (along its leading axis) over which the softmax is computed. Values can be repeated and need not be sorted. Values outside of the range [0, num_segments) are dropped and do not contribute to the result.
num_segments: optional, an int with positive value indicating the number of segments. The default is jnp.maximum(jnp.max(segment_ids) + 1, jnp.max(-segment_ids)), but since num_segments determines the size of the output, a static value must be provided to use segment_sum in a jit-compiled function.
indices_are_sorted: whether segment_ids is known to be sorted.
unique_indices: whether segment_ids is known to be free of duplicates.

Returns:
The segment softmax-ed logits.
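To illustrate what a segment-wise softmax computes, and why a static num_segments matters under jit, here is a minimal sketch built on jax.ops.segment_max and jax.ops.segment_sum. The name segment_softmax_sketch is hypothetical, and this is not necessarily how galax implements the function. The sketch assumes every entry of segment_ids lies in [0, num_segments) and does not reproduce the dropping of out-of-range ids.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical reference helper for illustration; not galax's implementation.
# num_segments is marked static so the per-segment arrays have a known shape
# at trace time.
@partial(jax.jit, static_argnames="num_segments")
def segment_softmax_sketch(data, segment_ids, num_segments):
    # Per-segment maximum, subtracted for numerical stability.
    seg_max = jax.ops.segment_max(data, segment_ids, num_segments=num_segments)
    exp = jnp.exp(data - seg_max[segment_ids])
    # Per-segment normalizer, gathered back to one value per element.
    seg_sum = jax.ops.segment_sum(exp, segment_ids, num_segments=num_segments)
    return exp / seg_sum[segment_ids]


logits = jnp.array([1.0, 2.0, 3.0, 1.0, 2.0])
segment_ids = jnp.array([0, 0, 0, 1, 1])
probs = segment_softmax_sketch(logits, segment_ids, num_segments=2)

# Sanity check: the probabilities within each segment sum to 1.
totals = jax.ops.segment_sum(probs, segment_ids, num_segments=2)
```

With the inputs above, probs should agree with the values in the docstring example, since the first three logits are normalized together and the last two are normalized together.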