This document is relevant for: Inf2, Trn1, Trn2
nki.isa.range_select
- nki.isa.range_select(*, on_true_tile, comp_op0, comp_op1, bound0, bound1, reduce_cmd=reduce_cmd.idle, reduce_res=None, reduce_op=<function amax>, range_start=0, on_false_value=<property object>, mask=None, dtype=None, **kwargs)
Select elements from on_true_tile based on comparison with bounds using Vector Engine.

For each element in on_true_tile, compares its free dimension index + range_start against bound0 and bound1 using the specified comparison operators (comp_op0 and comp_op1). If both comparisons evaluate to True, copies the element to the output; otherwise uses on_false_value.

Additionally performs a reduction operation specified by reduce_op on the results, storing the reduction result in reduce_res.

Note on numerical stability:

In self-attention, we often have this instruction sequence: range_select (VectorE) -> reduce_res -> activation (ScalarE). When range_select outputs a full row of the fill value (on_false_value), caution is needed to avoid NaN in the activation instruction that subtracts reduce_res (max value) from the output of range_select:

- If dtype and reduce_res are both FP32, we should not hit any NaN issue since FP32_MIN - FP32_MIN = 0. Exponentiation on 0 is stable (1.0 exactly).
- If dtype is FP16/BF16/FP8, the fill value in the output tile will become -INF since HW performs a downcast from FP32_MIN to a smaller dtype. In this case, you must make sure reduce_res uses FP32 dtype to avoid NaN in activation. NaN can be avoided because activation always upcasts input tiles to FP32 to perform math operations: -INF - FP32_MIN = -INF. Exponentiation on -INF is stable (0.0 exactly).
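The two cases in the stability note can be reproduced on the host. The following is a minimal NumPy sketch, not NKI code: it assumes the activation step computes exp(x - max) in FP32, uses NumPy's float16 as a stand-in for the hardware downcast (BF16 and FP8 are not modelled), and all variable names are illustrative.

```python
import numpy as np

FP32_MIN = np.finfo(np.float32).min  # most negative finite FP32 value

# Case 1: output tile and reduce_res are both FP32.
fp32_fill = np.float32(FP32_MIN)          # element of a fully masked row
fp32_max = np.float32(FP32_MIN)           # row max of that fully masked row
case_fp32 = np.exp(fp32_fill - fp32_max)  # exp(0.0)

with np.errstate(over="ignore", invalid="ignore"):
    # Case 2: the output tile is FP16, so FP32_MIN overflows to -inf on downcast.
    fp16_fill = np.float32(FP32_MIN).astype(np.float16)

    # 2a (unsafe): reduce_res is also FP16, so it holds -inf as well.
    fp16_max = fp16_fill
    case_unsafe = np.exp(np.float32(fp16_fill) - np.float32(fp16_max))  # -inf - (-inf)

    # 2b (safe): reduce_res stays FP32 and still holds FP32_MIN.
    case_safe = np.exp(np.float32(fp16_fill) - fp32_max)  # exp(-inf)

print(case_fp32, case_unsafe, case_safe)
```

Case 2a is the failure the note warns about: subtracting -INF from -INF yields NaN, which then propagates through the exponential.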
Constraints:

- The comparison operators must be one of:
  - np.equal
  - np.less
  - np.less_equal
  - np.greater
  - np.greater_equal
- Partition dim sizes must match across on_true_tile, bound0, and bound1:
  - bound0 and bound1 must have one element per partition
  - on_true_tile must be one of the FP dtypes, and bound0/bound1 must be FP32 types.
- The comparison with bound0, bound1, and free dimension index is done in FP32. Make sure range_start + free dimension index is within 2^24 range.

Estimated instruction cost:

max(MIN_II, N) Vector Engine cycles, where:

- N is the number of elements per partition in on_true_tile, and
- MIN_II is the minimum instruction initiation interval for small input tiles. MIN_II is roughly 64 engine cycles.
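The 2^24 limit and the cost formula above can be sketched as host-side helpers. This is an illustrative sketch only: the helper names are hypothetical, the check treats "within 2^24 range" as a bound on magnitude, and MIN_II is taken as exactly 64 even though the text only says "roughly 64".

```python
import numpy as np

FP32_EXACT_INT_LIMIT = 2 ** 24  # integers up to this magnitude are exact in FP32
APPROX_MIN_II = 64              # "roughly 64" per the text above; an assumption

def indices_fit_fp32(range_start, free_size):
    """True if every compared index (range_start + i) stays within 2^24."""
    largest_index = range_start + free_size - 1
    return max(abs(range_start), abs(largest_index)) <= FP32_EXACT_INT_LIMIT

def estimated_cycles(elements_per_partition, min_ii=APPROX_MIN_II):
    """max(MIN_II, N) from the cost estimate above."""
    return max(min_ii, elements_per_partition)

# Why the limit matters: 2^24 + 1 is not representable in FP32,
# so it rounds to 2^24 and the two indices compare equal.
collides = np.float32(2 ** 24 + 1) == np.float32(2 ** 24)

print(indices_fit_fp32(0, 512), estimated_cycles(16), estimated_cycles(512), collides)
```

Past 2^24, distinct indices can round to the same FP32 value, so a comparison against bound0 or bound1 may select or reject the wrong element.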
Numpy equivalent:
    indices = np.zeros(on_true_tile.shape)
    indices[:] = range_start + np.arange(on_true_tile[0].size)

    mask = comp_op0(indices, bound0) & comp_op1(indices, bound1)
    select_out_tile = np.where(mask, on_true_tile, on_false_value)
    reduce_tile = reduce_op(select_out_tile, axis=1, keepdims=True)
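To make the NumPy equivalent concrete, here is a small runnable sketch that wraps it in a function and applies it to a 2x4 tile. The function name range_select_ref and the sample values are illustrative assumptions; the bounds are shaped (partitions, 1) so they broadcast along the free dimension, matching the "one element per partition" constraint.

```python
import numpy as np

FP32_MIN = np.finfo(np.float32).min

def range_select_ref(on_true_tile, comp_op0, comp_op1, bound0, bound1,
                     range_start=0, on_false_value=FP32_MIN, reduce_op=np.max):
    """Host-side NumPy model; argument names mirror the documented parameters."""
    num_partitions, free_size = on_true_tile.shape
    # Index along the free dimension, identical for every partition.
    indices = np.broadcast_to(
        range_start + np.arange(free_size, dtype=np.float32),
        (num_partitions, free_size),
    )
    mask = comp_op0(indices, bound0) & comp_op1(indices, bound1)
    select_out = np.where(mask, on_true_tile, np.float32(on_false_value))
    reduce_out = reduce_op(select_out, axis=1, keepdims=True)
    return select_out, reduce_out

tile = np.arange(8, dtype=np.float32).reshape(2, 4)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
lower = np.array([[1.0], [0.0]], dtype=np.float32)   # bound0, one value per partition
upper = np.array([[3.0], [2.0]], dtype=np.float32)   # bound1, one value per partition

# Keep elements where lower <= index < upper.
selected, row_max = range_select_ref(tile, np.greater_equal, np.less, lower, upper)
print(selected)
print(row_max)
```

Row 0 keeps indices 1 and 2 (values 1 and 2), row 1 keeps indices 0 and 1 (values 4 and 5), and every other position holds FP32_MIN, so the row maxima are 2 and 5.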
- Parameters:
on_true_tile – input tile containing elements to select from
on_false_value – constant value to use when selection condition is False. Due to HW constraints, this must be the FP32 bit pattern of FP32_MIN (nl.fp32.min)
comp_op0 – first comparison operator
comp_op1 – second comparison operator
bound0 – tile with one element per partition for first comparison
bound1 – tile with one element per partition for second comparison
reduce_op – reduction operator to apply across the selected output. Currently only np.max is supported.
reduce_res – optional tile to store reduction results.
range_start – starting base offset for index array for the free dimension of on_true_tile. Defaults to 0, and must be a compile-time integer.
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type to cast the output type to (see Supported Data Types for more information); if not specified, it will default to be the same as the data type of the input tile.
- Returns:
output tile with selected elements
Example:
    import neuronxcc.nki as nki
    import neuronxcc.nki.isa as nisa
    import neuronxcc.nki.language as nl
    import numpy as np
    ...
    ##################################################################
    # Example 1:
    # Select elements where
    # bound0 <= range_start + index < bound1 and compute max reduction
    #
    # on_false_value must be nl.fp32.min
    ##################################################################
    on_true_tile = nl.load(on_true[...])
    bound0_tile = nl.load(bound0[...])
    bound1_tile = nl.load(bound1[...])

    reduce_res_tile = nl.ndarray((on_true.shape[0], 1), dtype=nl.float32, buffer=nl.sbuf)
    result = nl.ndarray(on_true.shape, dtype=nl.float32, buffer=nl.sbuf)

    result[...] = nisa.range_select(
        on_true_tile=on_true_tile,
        comp_op0=compare_op0,
        comp_op1=compare_op1,
        bound0=bound0_tile,
        bound1=bound1_tile,
        reduce_cmd=nisa.reduce_cmd.reset_reduce,
        reduce_res=reduce_res_tile,
        reduce_op=np.max,
        range_start=range_start,
        on_false_value=nl.fp32.min
    )

    nl.store(select_res[...], value=result[...])
    nl.store(reduce_result[...], value=reduce_res_tile[...])
Alternatively, reduce_cmd can be used to chain multiple calls to the same accumulation register to accumulate across multiple range_select calls. For example:

    import neuronxcc.nki as nki
    import neuronxcc.nki.isa as nisa
    import neuronxcc.nki.language as nl
    import numpy as np
    ...
    ##################################################################
    # Example 2.a: Initialize reduction with first range_select
    # Notice we don't pass reduce_res since the accumulation
    # register keeps track of the accumulation until we're ready to
    # read it. Also we use reset_reduce in order to "clobber" or zero
    # out the accumulation register before we start accumulating.
    #
    # Note: Since the type of these tensors is fp32, we use nl.fp32.min
    # for on_false_value due to HW constraints.
    ##################################################################
    on_true_tile = nl.load(on_true[...])
    bound0_tile = nl.load(bound0[...])
    bound1_tile = nl.load(bound1[...])

    reduce_res_sbuf = nl.ndarray((on_true.shape[0], 1), dtype=np.float32, buffer=nl.sbuf)
    result_sbuf = nl.ndarray(on_true.shape, dtype=np.float32, buffer=nl.sbuf)

    result_sbuf[...] = nisa.range_select(
        on_true_tile=on_true_tile,
        comp_op0=compare_op0,
        comp_op1=compare_op1,
        bound0=bound0_tile,
        bound1=bound1_tile,
        reduce_cmd=nisa.reduce_cmd.reset_reduce,
        reduce_op=np.max,
        range_start=range_start,
        on_false_value=nl.fp32.min
    )

    ##################################################################
    # Example 2.b: Chain multiple range_select operations
    # with reduction in an affine loop. Adding ones just lets us ensure
    # the reduction gets updated with new values.
    ##################################################################
    ones = nl.full(on_true.shape, fill_value=1, dtype=np.float32, buffer=nl.sbuf)

    # we are going to loop as if we're tiling on the partition dimension
    iteration_step_size = on_true_tile.shape[0]

    # Perform chained operations using an affine loop index for range_start
    for i in range(1, 2):
        # Update input values
        on_true_tile[...] = nl.add(on_true_tile, ones)

        # Continue reduction with updated values
        # notice, we still don't have reduce_res specified
        result_sbuf[...] = nisa.range_select(
            on_true_tile=on_true_tile,
            comp_op0=compare_op0,
            comp_op1=compare_op1,
            bound0=bound0_tile,
            bound1=bound1_tile,
            reduce_cmd=nisa.reduce_cmd.reduce,
            reduce_op=np.max,
            # we can also use index expressions for setting the start of the range
            range_start=range_start + (i * iteration_step_size),
            on_false_value=nl.fp32.min
        )

    range_start = range_start + (2 * iteration_step_size)

    ##################################################################
    # Example 2.c: Final iteration, we actually want the results to
    # return to the user so we pass reduce_res argument so the
    # reduction will be written from the accumulation
    # register to reduce_res_sbuf
    ##################################################################
    on_true_tile[...] = nl.add(on_true_tile, ones)
    result_sbuf[...] = nisa.range_select(
        on_true_tile=on_true_tile,
        comp_op0=compare_op0,
        comp_op1=compare_op1,
        bound0=bound0_tile,
        bound1=bound1_tile,
        reduce_cmd=nisa.reduce_cmd.reduce,
        reduce_res=reduce_res_sbuf[...],
        reduce_op=np.max,
        range_start=range_start,
        on_false_value=nl.fp32.min
    )

    nl.store(select_res[...], value=result_sbuf[...])
    nl.store(reduce_result[...], value=reduce_res_sbuf[...])
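The chaining pattern in Example 2 can be modelled on the host to see what the accumulation register is expected to hold after each call. This is a hedged NumPy sketch, not the hardware behaviour: the class MaxAccumulatorModel and the string commands are hypothetical stand-ins for the register and for nisa.reduce_cmd, and the model assumes reset_reduce discards the previous contents, reduce folds new row maxima into them, and idle leaves them untouched, as the comments in Example 2 describe.

```python
import numpy as np

FP32_MIN = np.finfo(np.float32).min

class MaxAccumulatorModel:
    """Hypothetical host-side stand-in for the per-partition accumulation register."""

    def __init__(self, num_partitions):
        self.value = np.full((num_partitions, 1), FP32_MIN, dtype=np.float32)

    def step(self, selected_tile, cmd):
        # Row-wise max of one range_select output tile.
        row_max = selected_tile.max(axis=1, keepdims=True)
        if cmd == "reset_reduce":
            self.value = row_max                          # discard previous contents
        elif cmd == "reduce":
            self.value = np.maximum(self.value, row_max)  # keep accumulating
        elif cmd != "idle":
            raise ValueError(f"unknown command: {cmd}")
        return self.value

# Three already-selected tiles, two partitions each (sample values are made up).
tile_a = np.array([[1, 5], [2, 0]], dtype=np.float32)
tile_b = np.array([[3, 4], [9, 1]], dtype=np.float32)
tile_c = np.array([[0, 2], [1, 1]], dtype=np.float32)

acc = MaxAccumulatorModel(num_partitions=2)
after_a = acc.step(tile_a, "reset_reduce").copy()  # like Example 2.a
after_b = acc.step(tile_b, "reduce").copy()        # like Example 2.b
after_c = acc.step(tile_c, "reduce").copy()        # like Example 2.c, then read out
print(after_a.ravel(), after_b.ravel(), after_c.ravel())
```

In this model the final read-out is the maximum over all three tiles per partition, which is what passing reduce_res on the last call is meant to expose.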