pypto.index_add_#
产品支持情况#
产品 |
是否支持 |
|---|---|
Ascend 950PR/Ascend 950DT |
√ |
Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
功能说明#
将source的每一块数据乘以缩放因子alpha(默认为1)加到input的相应数据块上,其中索引和数据块方向由index和dim指定。
函数原型#
index_add_(input: Tensor, dim: int, index: Tensor, source: Tensor, *, alpha: Union[int, float] = 1) -> Tensor
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
input |
输入 |
源操作数。 |
dim |
输入 |
int 类型,加法作用到 input 的维度; |
index |
输入 |
源操作数,值代表 input 所在 dim 轴的索引; |
source |
输入 |
需要加到 input 的源操作数; |
alpha |
输入 |
标量,关键字参数; |
返回值说明#
原地操作返回 input
约束说明#
index必须是整数类型(DT_INT32 或 DT_INT64),值不超过input在dim维度上的Shape大小,维数为1,Shape大小与 source 所在dim轴的Shape大小相同;
dim为int类型,取值范围:-input.dim <= dim < input.dim;
input和source的数据类型和维数均相同;
input.shape和source.shape的dim轴viewshape不可切,要求viewshape[dim]>=max(input.shape[dim], source.shape[dim]),其余维度的Shape大小不做限制;
TileShape的维度与input相同,input, source 的 dim 轴以及 index 均不可切,所有输入和输出的TileShape大小总和不能超过UB内存的大小。
调用示例#
TileShape设置示例#
调用该operation接口前,应通过set_vec_tile_shapes设置TileShape。
TileShape维度应和输出一致。
如输入input为[m, n, p],dim为1,输入source为[m, t, p],输入index为[t],输出为[m, n, p],TileShape设置为[m1, t1, p1],则m1, p1分别用于切分m, p轴。 n轴,t轴不可切,必须保证n轴t轴全载。
pypto.set_vec_tile_shapes(4, 16, 32)
接口调用示例#
x = pypto.tensor([2, 3], pypto.DT_INT32) # shape (2, 3)
source = pypto.tensor([3, 3], pypto.DT_INT32) # shape (3, 3)
index = pypto.tensor([3], pypto.DT_INT32) # shape (3,)
dim = 0
# use alpha
y = pypto.index_add_(x, dim, index, source, alpha=1)
# not use alpha
y = pypto.index_add_(x, dim, index, source)
结果示例如下:
输入数据 x: [[0 0 0],
[0 0 0]]
source: [[1 1 1],
[1 1 1],
[1 1 1]]
index: [0 1 0]
输出数据 y: [[2 2 2],
[1 1 1]] # shape (2, 3)