pypto.index_select#
产品支持情况#
产品 |
是否支持 |
|---|---|
Ascend 950PR/Ascend 950DT |
√ |
Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
功能说明#
返回一个新的张量,该张量使用索引 index 中的元素沿维度 dim 对输入张量进行索引。
返回的张量与原始张量(输入)具有相同的维度数。第 dim 维度的大小与索引 index 的长度相同;其他维度的大小与原始张量相同。
函数原型#
index_select(input: Tensor, dim: int, index: Tensor) -> Tensor:
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
input |
输入 |
源操作数。 |
dim |
输入 |
int 类型,索引的维度; |
index |
输入 |
源操作数; |
返回值说明#
返回输出 Tensor,输出 Tensor数据类型与 input 数据类型保持一致;输出 Tensor的Shape 有input、 dim 以及 index 共同确定,详见功能说明。
约束说明#
index必须是整数类型(DT_INT32 或 DT_INT64),值为合法索引,即不能超出 input.shape[dim];
dim为int类型,取值范围:-input.dim <= dim < input.dim。支持负数,负值会被解释为 dim + input.dim;
input.shape 的 dim 轴 viewshape 不可切,要求 viewshape[dim]>=input.shape[dim],其余维度的Shape大小不做限制;
TileShape的维度与result相同,用于切分result。TileShape 设置需保证 result 不超过UB大小,具体用法详见 TileShape设置示例
调用示例#
TileShape设置示例#
调用该operation接口前,应通过set_vec_tile_shapes设置TileShape。
TileShape 的维度设置须与输出张量保持一致,用于控制输出 Tile 块的大小。
以输入\( input[B,S,D]\) 、索引 \(index[T]\) 、轴 \( ext{axis}=-2\) 、输出 \(output[B,T,D]\) 为例:设 TileShape 为\([b_1, t_1, d_1]\),该配置直接作用于输出 output 的各维度,同时映射至输入与索引。其中 \(b_1\) 切分 input 的批次维 B ,\(d_1\) 切分 input 的特征维 D ,而输入的序列维 S (即轴 −2 )不参与切分,仅作为索引源; \(t_1\) 则作用于索引 index 的长度维 T 。Tile 内存占用须满足约束 \(b_1 dot t_1 dot d_1 dot ext{sizeof}(athbf{output}) < ext{UBSize}\)
接口调用示例#
x = pypto.tensor([3, 4], pypto.DT_FP32)
indices = pypto.tensor([2,], pypto.DT_INT32)
out1 = pypto.index_select(x, 0, indices)
out2 = pypto.index_select(x, 1, indices)
结果示例如下:
输入 x: [[ 0.1427, 0.0231, -0.5414, -1.0009],
[-0.4664, 0.2647, -0.1228, -1.1068],
[-1.1734, -0.6571, 0.7230, -0.6004]]
输入 index: [0, 2]
输出 out1 : [[ 0.1427, 0.0231, -0.5414, -1.0009],
[-1.1734, -0.6571, 0.7230, -0.6004]]
输出 out2 : [[ 0.1427, -0.5414],
[-0.4664, -0.1228],
[-1.1734, 0.7230]]