galax.function.segment_max

galax.function.segment_max(*args, **kwargs)

Alias of jax.ops.segment_max with nan_to_num.
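The sketch below shows what "with nan_to_num" plausibly means. It assumes that jax.numpy.nan_to_num is applied to the output of jax.ops.segment_max and that all arguments are forwarded unchanged. The helper name segment_max_nan_to_num is hypothetical, and this is not galax's actual source. For floating-point inputs, jax.ops.segment_max fills segments that receive no elements with -inf. nan_to_num then replaces that value with the most negative finite value of the dtype, and it replaces any NaN with 0.

```python
import jax.numpy as jnp
from jax import ops


def segment_max_nan_to_num(*args, **kwargs):
    """Hypothetical sketch: segment_max followed by nan_to_num.

    Assumption: this mirrors galax.function.segment_max by forwarding all
    arguments to jax.ops.segment_max and cleaning the result.
    """
    out = ops.segment_max(*args, **kwargs)
    # Empty segments come back as -inf for float dtypes; nan_to_num maps
    # -inf to finfo(dtype).min, +inf to finfo(dtype).max, and NaN to 0.
    return jnp.nan_to_num(out)


data = jnp.array([1.0, 5.0, 2.0])
segment_ids = jnp.array([0, 0, 2])  # segment 1 receives no elements
result = segment_max_nan_to_num(data, segment_ids, num_segments=3)
```

Here segment 0 holds max(1.0, 5.0) = 5.0 and segment 2 holds 2.0. Segment 1 is empty, so it holds the dtype's most negative finite value rather than -inf.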