This document is relevant for: Inf2, Trn1, Trn2

nki.isa.range_select

nki.isa.range_select(*, on_true_tile, comp_op0, comp_op1, bound0, bound1, reduce_cmd=reduce_cmd.idle, reduce_res=None, reduce_op=<function amax>, range_start=0, on_false_value=<property object>, mask=None, dtype=None, **kwargs)

Select elements from on_true_tile by comparing their free-dimension indices against per-partition bounds, using the Vector Engine.

For each element in on_true_tile, the instruction computes range_start + the element's free-dimension index and evaluates comp_op0(index, bound0) and comp_op1(index, bound1), where bound0 and bound1 hold one value per partition. If both comparisons evaluate to True, the element is copied to the output; otherwise the output element is set to on_false_value.

Additionally, the instruction reduces the selected output along the free dimension with reduce_op (one result per partition) and writes the result to reduce_res.

Note on numerical stability:

In self-attention, a common instruction sequence is range_select (VectorE) -> reduce_res -> activation (ScalarE). When range_select outputs a full row of the fill value (on_false_value), care is needed to avoid NaN in the activation instruction that subtracts reduce_res (the row max) from the output of range_select:

  • If dtype and reduce_res are both FP32, no NaN occurs, because FP32_MIN - FP32_MIN = 0 and exponentiating 0 gives exactly 1.0.

  • If dtype is FP16/BF16/FP8, the fill value in the output tile becomes -INF, because the hardware downcasts FP32_MIN to the smaller dtype. In this case, reduce_res must use the FP32 dtype to avoid NaN in the activation. NaN is avoided because the activation always upcasts input tiles to FP32 before performing math operations, so -INF - FP32_MIN = -INF, and exponentiating -INF gives exactly 0.0.
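The two cases above can be reproduced on the host with NumPy. This is a sketch, not hardware behavior: it assumes nl.fp32.min corresponds to np.finfo(np.float32).min, uses FP16 as a stand-in for FP16/BF16/FP8 (NumPy has no BF16 or FP8), and models the activation's FP32 upcast with an explicit astype. The last case shows the NaN that appears if reduce_res is kept in the low-precision dtype.

```python
import numpy as np

# Assumed to match nl.fp32.min: the most negative finite FP32 value (FP32_MIN).
FP32_MIN = np.finfo(np.float32).min

# A row where range_select selected nothing: every element is the fill value.
row_fp32 = np.full(4, FP32_MIN, dtype=np.float32)

# Case 1: output dtype and reduce_res are both FP32.
reduce_res = row_fp32.max(keepdims=True)        # FP32_MIN
case1 = np.exp(row_fp32 - reduce_res)           # FP32_MIN - FP32_MIN = 0, exp(0) = 1

# Case 2: output downcast to FP16 (stand-in for a smaller dtype); the fill overflows to -inf.
with np.errstate(over="ignore"):
    row_fp16 = row_fp32.astype(np.float16)
# The activation upcasts to FP32 while reduce_res stays FP32: -inf - FP32_MIN = -inf, exp = 0.
case2 = np.exp(row_fp16.astype(np.float32) - reduce_res)

# Pitfall: a reduce_res in the small dtype is itself -inf, and -inf - (-inf) = NaN.
reduce_res_fp16 = row_fp16.max(keepdims=True)
with np.errstate(invalid="ignore"):
    case3 = np.exp(row_fp16.astype(np.float32) - reduce_res_fp16.astype(np.float32))
```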

Constraints:

The comparison operators must be one of:

  • np.equal

  • np.less

  • np.less_equal

  • np.greater

  • np.greater_equal

on_true_tile, bound0, and bound1 must satisfy the following:

  • Their partition dimension sizes must match, and bound0 and bound1 must have exactly one element per partition.

  • on_true_tile must have a floating-point dtype, and bound0 and bound1 must be FP32.
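The constraints above can be mirrored by a small host-side check, for example in a test harness that prepares inputs for a kernel. The function check_range_select_args is hypothetical and not part of NKI; it also assumes that "one element per partition" means a bound of shape (P, 1), matching the (P, 1) reduce_res tiles in the examples.

```python
import numpy as np

# Comparison operators listed in the constraints above.
ALLOWED_COMP_OPS = {np.equal, np.less, np.less_equal, np.greater, np.greater_equal}


def check_range_select_args(on_true_tile, bound0, bound1, comp_op0, comp_op1):
    """Hypothetical host-side check of the documented constraints (not an NKI API)."""
    for op in (comp_op0, comp_op1):
        if op not in ALLOWED_COMP_OPS:
            raise ValueError(f"unsupported comparison operator: {op}")
    num_partitions = on_true_tile.shape[0]
    for name, bound in (("bound0", bound0), ("bound1", bound1)):
        # Assumed (P, 1) layout: one element per partition, same P as on_true_tile.
        if bound.shape != (num_partitions, 1):
            raise ValueError(f"{name} must have shape ({num_partitions}, 1), got {bound.shape}")
        if bound.dtype != np.float32:
            raise TypeError(f"{name} must be float32, got {bound.dtype}")
    if not np.issubdtype(on_true_tile.dtype, np.floating):
        raise TypeError(f"on_true_tile must be floating point, got {on_true_tile.dtype}")
```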

The comparison of range_start + free-dimension index against bound0 and bound1 is performed in FP32, which represents integers exactly only up to 2^24. Make sure range_start + free-dimension index does not exceed 2^24.
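To see why the 2^24 limit matters, the NumPy sketch below compares indices just past 2^24 against a bound, once with exact integers and once after converting both sides to FP32 as the hardware comparison does. The values of range_start and bound are chosen only for illustration.

```python
import numpy as np

# Indices that straddle 2**24, the largest range where FP32 holds every integer exactly.
range_start = 2**24 - 2
idx_exact = range_start + np.arange(4)        # 16777214, 16777215, 16777216, 16777217
idx_fp32 = idx_exact.astype(np.float32)       # 16777217 rounds to 16777216 in FP32

bound = 2**24 + 1
exact_hits = np.equal(idx_exact, bound)            # only the last index matches
fp32_hits = np.equal(idx_fp32, np.float32(bound))  # bound also rounds, so two indices match
```

With np.equal as a comparison operator, the FP32 comparison selects an extra element; the other operators can misbehave in the same way once indices exceed 2^24.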

Estimated instruction cost:

max(MIN_II, N) Vector Engine cycles, where:

  • N is the number of elements per partition in on_true_tile, and

  • MIN_II is the minimum instruction initiation interval for small input tiles, roughly 64 engine cycles.
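As a quick worked example of the cost formula, the helper below plugs in the approximate MIN_II of 64 cycles. The function name is hypothetical and the result is only the rough estimate given above, not a measured cycle count.

```python
# Approximate value quoted above; actual hardware timing may differ.
MIN_II = 64


def estimated_range_select_cycles(n_free_elems: int) -> int:
    """Estimated Vector Engine cycles for n_free_elems elements per partition."""
    return max(MIN_II, n_free_elems)


small = estimated_range_select_cycles(32)    # small tile: bounded below by MIN_II
large = estimated_range_select_cycles(2048)  # large tile: proportional to N
```

Tiles with fewer than about 64 elements per partition cost the same as a 64-element tile, so very narrow tiles do not pay off.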

Numpy equivalent:

indices = np.zeros(on_true_tile.shape)
indices[:] = range_start + np.arange(on_true_tile[0].size)

mask = comp_op0(indices, bound0) & comp_op1(indices, bound1)
select_out_tile = np.where(mask, on_true_tile, on_false_value)
reduce_tile = reduce_op(select_out_tile, axis=1, keepdims=True)
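The NumPy equivalent above can be wrapped into a self-contained reference function and run on a small input. range_select_ref is a hypothetical host-side model, not an NKI API; it assumes bound0 and bound1 have shape (P, 1), performs the index comparison in FP32, and uses np.finfo(np.float32).min as the default on_false_value.

```python
import numpy as np

FP32_MIN = np.finfo(np.float32).min


def range_select_ref(on_true_tile, comp_op0, comp_op1, bound0, bound1,
                     range_start=0, on_false_value=FP32_MIN):
    """Hypothetical NumPy model of range_select's two outputs."""
    num_free = on_true_tile.shape[1]
    # Same offset free-dimension index for every partition, compared in FP32.
    indices = (range_start + np.arange(num_free)).astype(np.float32)[np.newaxis, :]
    # bound0/bound1 of shape (P, 1) broadcast across the free dimension.
    mask = comp_op0(indices, bound0) & comp_op1(indices, bound1)
    select_out = np.where(mask, on_true_tile, np.float32(on_false_value)).astype(np.float32)
    # Max reduction along the free dimension: one value per partition.
    reduce_res = np.max(select_out, axis=1, keepdims=True)
    return select_out, reduce_res


# Two partitions; keep bound0 <= index < bound1 in each row.
on_true = np.arange(8, dtype=np.float32).reshape(2, 4)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
bound0 = np.array([[1.0], [0.0]], dtype=np.float32)
bound1 = np.array([[3.0], [2.0]], dtype=np.float32)

select_out, reduce_res = range_select_ref(on_true, np.greater_equal, np.less, bound0, bound1)
```

Row 0 keeps indices 1 and 2, row 1 keeps indices 0 and 1, and every other element becomes FP32_MIN, so reduce_res holds the per-row maxima 2.0 and 5.0.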

Parameters:
  • on_true_tile – input tile containing elements to select from

  • on_false_value – constant value to use where the selection condition is False. Due to hardware constraints, this must be the FP32_MIN bit pattern (nl.fp32.min).

  • comp_op0 – first comparison operator

  • comp_op1 – second comparison operator

  • bound0 – tile with one element per partition for first comparison

  • bound1 – tile with one element per partition for second comparison

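  • reduce_op – reduction operator applied along the free dimension of the selected output. Currently only np.max is supported.

  • reduce_cmd – controls the accumulation register that holds the reduction. Defaults to reduce_cmd.idle. reduce_cmd.reset_reduce resets the register before accumulating, and reduce_cmd.reduce keeps accumulating into it, so that several range_select calls can share one reduction (see the examples below).

  • reduce_res – optional tile to store the reduction results, with one element per partition. When omitted, the reduction stays in the accumulation register until a later call reads it out.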

  • range_start – starting offset added to the free-dimension index of on_true_tile. Defaults to 0 and must be a compile-time integer.

  • mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)

  • dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.

Returns:

output tile with selected elements

Example:

import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
...

##################################################################
# Example 1: Select elements where
# bound0 <= range_start + index < bound1 and compute max reduction
# 
# on_false_value must be nl.fp32.min
##################################################################
on_true_tile = nl.load(on_true[...])
bound0_tile = nl.load(bound0[...])
bound1_tile = nl.load(bound1[...])

reduce_res_tile = nl.ndarray((on_true.shape[0], 1), dtype=nl.float32, buffer=nl.sbuf)
result = nl.ndarray(on_true.shape, dtype=nl.float32, buffer=nl.sbuf)

result[...] = nisa.range_select(
    on_true_tile=on_true_tile,
    comp_op0=compare_op0,
    comp_op1=compare_op1,
    bound0=bound0_tile,
    bound1=bound1_tile,
    reduce_cmd=nisa.reduce_cmd.reset_reduce,
    reduce_res=reduce_res_tile,
    reduce_op=np.max,
    range_start=range_start,
    on_false_value=nl.fp32.min
)

nl.store(select_res[...], value=result[...])
nl.store(reduce_result[...], value=reduce_res_tile[...])

Alternatively, reduce_cmd can be used to chain several range_select calls into the same accumulation register, so that the reduction accumulates across all of them. For example:

import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
...

##################################################################
# Example 2.a: Initialize reduction with first range_select
# Notice we don't pass reduce_res since the accumulation
# register keeps track of the accumulation until we're ready to 
# read it. Also we use reset_reduce in order to "clobber" or zero
# out the accumulation register before we start accumulating.
#
# Note: on_false_value must be nl.fp32.min due to HW constraints,
# which is why every call below passes it.
##################################################################
on_true_tile = nl.load(on_true[...])
bound0_tile = nl.load(bound0[...])
bound1_tile = nl.load(bound1[...])

reduce_res_sbuf = nl.ndarray((on_true.shape[0], 1), dtype=np.float32, buffer=nl.sbuf)
result_sbuf = nl.ndarray(on_true.shape, dtype=np.float32, buffer=nl.sbuf)

result_sbuf[...] = nisa.range_select(
    on_true_tile=on_true_tile,
    comp_op0=compare_op0,
    comp_op1=compare_op1,
    bound0=bound0_tile,
    bound1=bound1_tile,
    reduce_cmd=nisa.reduce_cmd.reset_reduce,
    reduce_op=np.max,
    range_start=range_start,
    on_false_value=nl.fp32.min
)

##################################################################
# Example 2.b: Chain multiple range_select operations
# with reduction in a loop. Adding ones ensures that each call
# feeds new values into the reduction.
##################################################################
ones = nl.full(on_true.shape, fill_value=1, dtype=np.float32, buffer=nl.sbuf)
# we are going to loop as if we're tiling on the partition dimension    
iteration_step_size = on_true_tile.shape[0]

# Perform chained operations using an affine loop index for range_start
for i in range(1, 2):
    # Update input values
    on_true_tile[...] = nl.add(on_true_tile, ones)
    
    # Continue reduction with updated values
    # notice, we still don't have reduce_res specified
    result_sbuf[...] = nisa.range_select(
        on_true_tile=on_true_tile,
        comp_op0=compare_op0,
        comp_op1=compare_op1,
        bound0=bound0_tile,
        bound1=bound1_tile,
        reduce_cmd=nisa.reduce_cmd.reduce,
        reduce_op=np.max,
        # we can also use index expressions for setting the start of the range
        range_start=range_start + (i * iteration_step_size),
        on_false_value=nl.fp32.min
    )

range_start = range_start + (2 * iteration_step_size)
##################################################################
# Example 2.c: Final call. We want to return the results to
# the user, so we pass the reduce_res argument; the reduction
# is then written from the accumulation register to
# reduce_res_sbuf
##################################################################
on_true_tile[...] = nl.add(on_true_tile, ones)
result_sbuf[...] = nisa.range_select(
    on_true_tile=on_true_tile,
    comp_op0=compare_op0,
    comp_op1=compare_op1,
    bound0=bound0_tile,
    bound1=bound1_tile,
    reduce_cmd=nisa.reduce_cmd.reduce,
    reduce_res=reduce_res_sbuf[...],
    reduce_op=np.max,
    range_start=range_start,
    on_false_value=nl.fp32.min
)

nl.store(select_res[...], value=result_sbuf[...])
nl.store(reduce_result[...], value=reduce_res_sbuf[...])
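The chaining pattern in Examples 2.a to 2.c can be modeled on the host as a per-partition running maximum. MaxAccumulatorModel is a hypothetical, simplified sketch of the accumulation register, not an NKI API: it takes each call's already-selected output, treats reset_reduce as "discard the old value and start from this call", and treats reduce as "fold this call into the running max".

```python
import numpy as np


class MaxAccumulatorModel:
    """Hypothetical per-partition model of the accumulation register (not an NKI API)."""

    def __init__(self, num_partitions):
        self.acc = np.full((num_partitions, 1), -np.inf, dtype=np.float32)

    def reset_reduce(self, select_out):
        # Clobber the register, then start from this call's row maxima.
        self.acc = select_out.max(axis=1, keepdims=True)

    def reduce(self, select_out):
        # Fold this call's row maxima into the running value.
        self.acc = np.maximum(self.acc, select_out.max(axis=1, keepdims=True))


# Selected outputs of three chained calls; each adds ones, as in the examples.
tile = np.array([[0.0, 3.0], [5.0, 1.0]], dtype=np.float32)
acc = MaxAccumulatorModel(num_partitions=2)
acc.reset_reduce(tile)   # Example 2.a: reset_reduce, no reduce_res
acc.reduce(tile + 1)     # Example 2.b: reduce, no reduce_res
acc.reduce(tile + 2)     # Example 2.c: reduce; reduce_res would receive acc.acc
```

Because the max is associative, the final value equals the max over all three calls' selected outputs, which is what reduce_res_sbuf receives in Example 2.c.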
