pypto.scatter_#
产品支持情况#
产品 |
是否支持 |
|---|---|
Ascend 950PR/Ascend 950DT |
√ |
Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
功能说明#
将src的值写入input中。写入位置由index指定。3维计算公式如下,其他维度以此类推:
src为固定标量时:
$\(
\begin{cases}
input\left[ index\left[i\right]\left[j\right]\left[k\right] \right]\left[j\right]\left[k\right] = src & \text{if } dim = 0 \\
input\left[i\right]\left[ index\left[i\right]\left[j\right]\left[k\right] \right]\left[k\right] = src & \text{if } dim = 1 \\
input\left[i\right]\left[j\right]\left[ index\left[i\right]\left[j\right]\left[k\right] \right] = src & \text{if } dim = 2
\end{cases}
\)\(
<br> src为Tensor时:
\)\(
\begin{cases}
input\left[ index\left[i\right]\left[j\right]\left[k\right] \right]\left[j\right]\left[k\right] = src\left[i\right]\left[j\right]\left[k\right] & \text{if } dim = 0 \\
input\left[i\right]\left[ index\left[i\right]\left[j\right]\left[k\right] \right]\left[k\right] = src\left[i\right]\left[j\right]\left[k\right] & \text{if } dim = 1 \\
input\left[i\right]\left[j\right]\left[ index\left[i\right]\left[j\right]\left[k\right] \right] = src\left[i\right]\left[j\right]\left[k\right] & \text{if } dim = 2
\end{cases}
\)$
函数原型#
scatter_(input: Tensor, dim: int, index: Tensor, src: Union[float, Element, Tensor], *, reduce: str = None) -> Tensor
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
input |
输入 |
支持的类型为:Tensor。 |
dim |
输入 |
指定用于索引的维度,支持input的维度范围内的任意维度。 |
index |
输入 |
input的一组索引。 |
src |
输入 |
src是更新的标量或Tensor。 |
reduce |
输入 |
要应用的归约操作,支持 ‘add’ 或 ‘multiply’,不传参时默认为直接替换 |
返回值说明#
返回更新后的 input,为inplace操作。
约束说明#
broadcast约束:input和index不支持broadcast;
input.shape的dim轴不可切,viewshape的维度与input维度相同,要求viewshape[dim] >= max( input.shape[dim], index.shape[dim] ),其余维度的Shape大小不做限制;
input.shape的dim轴不可切,tileshape的维度与input维度相同,tileshape[dim] >= viewshape[dim],其余维度的Shape大小不做限制,input index 和 result 都会在 UB 中,需满足所有输入和输出的 tileshape 大小总和不能超过UB内存的大小。
input.shape和index.shape的非dim轴切分,需满足viewshape[non dim]切分后,input和index的非dim轴切分块数相同。tileshape切分时,也需要保证input和index的非dim轴切分块数相同。
src为Tensor,dim为尾轴,reduce为None,且当index每行数据内存在不唯一索引时,行为是不确定的,将从src中任意选择一个值
调用示例#
TileShape设置示例#
调用该operation接口前,应通过set_vec_tile_shapes设置TileShape。
TileShape维度应和输出一致。
如输入intput shape为[a, b, c],dim为1,index为[m, t, p](其中m<=a, p<=c),src为[x, y, z](其中x>=m, y>=t, z>=p),输出为[a, b, c], TileShape设置为[m1, t1, p1]。 则m1, p1分别用于切分m, p轴。t1必须大于等于b和t,dim对应轴不可切,必须保证b轴和t轴全载。
pypto.set_vec_tile_shapes(4, 16, 32)
接口调用示例#
将2维 input 根据2维index更新对应索引的值
x = pypto.tensor([3, 5], pypto.DT_FP32) y = pypto.tensor([2, 2], pypto.DT_INT64) o = pypto.scatter_(x, 0, y, 2.0)
结果示例如下:
输入数据x:[[0 0 0 0 0], [0 0 0 0 0], [0 0 0 0 0]] 输入数据y:[[1 2], [0 1]] 输出数据o:[[2.0 0 0 0 0], [2.0 2.0 0 0 0], [0 2.0 0 0 0]]