pypto.index_add_#

产品支持情况#

产品

是否支持

Ascend 950PR/Ascend 950DT

Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas A2 训练系列产品/Atlas A2 推理系列产品

功能说明#

将source的每一块数据乘以缩放因子alpha(默认为1)加到input的相应数据块上,其中索引和数据块方向由index和dim指定。

函数原型#

index_add_(input: Tensor, dim: int, index: Tensor, source: Tensor, *, alpha: Union[int, float] = 1) -> Tensor

参数说明#

参数名

输入/输出

说明

input

输入

源操作数。
支持的类型为:Tensor。
Tensor支持的数据类型为:DT_FP32,DT_FP16,DT_BF16,DT_INT16,DT_INT32。
不支持空Tensor;Shape仅支持2-4维;Shape Size不大于2147483647(即INT32_MAX)。

dim

输入

int 类型,加法作用到 input 的维度;
支持任意不超过 input 维数的值,详见约束说明。

index

输入

源操作数,值代表 input 所在 dim 轴的索引;
支持的类型为:Tensor。
Tensor支持的数据类型为:DT_INT32,DT_INT64;
不支持空 Tensor,Shape只支持1维,索引与 source 的 dim 轴索引一一对应,Shape大小与 source 所在 dim 轴的Shape大小相同。

source

输入

需要加到 input 的源操作数;
支持的类型为:Tensor。
Tensor的数据类型 与 input 相同。
Shape支持2-4维,所在 dim 轴的Shape大小与 index 相同,其他维度的Shape大小与 input 相同。

alpha

输入

标量,关键字参数;
表示累加时的缩放因子,默认为 1。

返回值说明#

原地操作返回 input

约束说明#

  1. index必须是整数类型(DT_INT32 或 DT_INT64),值不超过input在dim维度上的Shape大小,维数为1,Shape大小与 source 所在dim轴的Shape大小相同;

  2. dim为int类型,取值范围:-input.dim <= dim < input.dim;

  3. input和source的数据类型和维数均相同;

  4. input.shape和source.shape的dim轴viewshape不可切,要求viewshape[dim]>=max(input.shape[dim], source.shape[dim]),其余维度的Shape大小不做限制;

  5. TileShape的维度与input相同,input, source 的 dim 轴以及 index 均不可切,所有输入和输出的TileShape大小总和不能超过UB内存的大小。

调用示例#

TileShape设置示例#

调用该operation接口前,应通过set_vec_tile_shapes设置TileShape。

TileShape维度应和输出一致。

如输入input为[m, n, p],dim为1,输入source为[m, t, p],输入index为[t],输出为[m, n, p],TileShape设置为[m1, t1, p1],则m1, p1分别用于切分m, p轴。 n轴,t轴不可切,必须保证n轴t轴全载。

pypto.set_vec_tile_shapes(4, 16, 32)

接口调用示例#

x = pypto.tensor([2, 3], pypto.DT_INT32)        # shape (2, 3)
source = pypto.tensor([3, 3], pypto.DT_INT32)   # shape (3, 3)
index = pypto.tensor([3], pypto.DT_INT32)   # shape (3,)
dim = 0
# use alpha
y = pypto.index_add_(x, dim, index, source, alpha=1)
# not use alpha
y = pypto.index_add_(x, dim, index, source)

结果示例如下:

输入数据 x:   [[0 0 0],
               [0 0 0]]
      source: [[1 1 1],
               [1 1 1],
               [1 1 1]]
      index:   [0 1 0]
输出数据 y:   [[2 2 2],
               [1 1 1]]               # shape (2, 3)