pypto.index_add_#
产品支持情况#
产品 |
是否支持 |
|---|---|
Ascend 950PR/Ascend 950DT |
√ |
Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
功能说明#
将source的每一块数据乘以缩放因子alpha(默认为1)加到input的相应数据块上,其中索引和数据块方向由index和dim指定。例如, $\( input[index[i], :, :] += alpha * source[i, :, :],\ if \ dim == 0, \\ input[:, index[i], :] += alpha * source[:, i, :],\ if \ dim == 1,\\ input[:, :, index[i]] += alpha * source[:, :, i],\ if \ dim == 2.\)$
函数原型#
index_add_(input: Tensor, dim: int, index: Tensor, source: Tensor, *, alpha: Union[int, float] = 1) -> Tensor
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
input |
输入 |
源操作数。 |
dim |
输入 |
int 类型,加法作用到 input 的维度; |
index |
输入 |
源操作数,值代表 input 所在 dim 轴的索引; |
source |
输入 |
需要加到 input 的源操作数; |
alpha |
输入 |
标量,关键字参数; |
返回值说明#
原地操作返回 input
约束说明#
index必须是整数类型(DT_INT32 或 DT_INT64),值不超过 input 在 dim 维度上的Shape大小,维数为1,Shape大小与 source 所在dim轴的Shape大小相同;
dim为int类型,取值范围:\(-input.dim\leq dim < input.dim\);
input和source的数据类型和维数均相同;
input.shape 和 source.shape 的非 dim 轴 ViewShape 不可切,即 \(ViewShape[i] \geq input.shape[i]=source.shape[i], i \ne dim\);
TileShape的维度与 source 相同,只用来切分 source 和 index,所有输入和输出的TileShape大小总和不能超过UB内存的大小。
调用示例#
TileShape设置示例#
调用该operation接口前,应通过set_vec_tile_shapes设置TileShape。
如输入input为[m, n, p],dim为1,输入source为[m, t, p],输入index为[t],输出为[m, n, p],TileShape设置为[m1, t1, p1],则m1, t1, p1分别用于切分source的 m, t, p轴。
pypto.set_vec_tile_shapes(4, 16, 32)
接口调用示例#
x = pypto.tensor([2, 3], pypto.DT_INT32) # shape (2, 3)
source = pypto.tensor([3, 3], pypto.DT_INT32) # shape (3, 3)
index = pypto.tensor([3], pypto.DT_INT32) # shape (3,)
dim = 0
# use alpha
y = pypto.index_add_(x, dim, index, source, alpha=1)
# not use alpha
y = pypto.index_add_(x, dim, index, source)
结果示例如下:
输入数据 x: [[0 0 0],
[0 0 0]]
source: [[1 1 1],
[1 1 1],
[1 1 1]]
index: [0 1 0]
输出数据 y: [[2 2 2],
[1 1 1]] # shape (2, 3)