pypto.scatter_update#
产品支持情况#
产品 |
是否支持 |
|---|---|
Ascend 950PR/Ascend 950DT |
√ |
Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
功能说明#
功能1:原地操作,将4维src根据2维索引index更新到4维 input 上,计算公式如下:
功能2:原地操作,将2维src根据2维index更新到2维 input 上,计算公式如下(其中s是index第二维的大小,即index.shape[1]):
函数原型#
scatter_update(input: Tensor, dim: int, index: Tensor, src: Tensor) -> Tensor
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
input |
输入 |
支持的类型为:Tensor。 |
dim |
输入 |
请保持默认值-2。 |
index |
输入 |
input的一组索引。 |
src |
输入 |
src是一组更新值。 |
返回值说明#
返回更新后的 input,为inplace操作。
约束说明#
broadcast 约束:不支持 broadcast。
ViewShape 约束:2维场景下ViewShape为[viewB * s, d],4维场景下ViewShape为[viewB, viewS, 1, d],尾轴d不做切分。2维场景下,[viewB * s, d]针对src做切分,其第0维是index的第1维s的倍数,[viewB, S]针对index做切分。4维场景下,[viewB, viewS, 1, d]针对src做切分,[viewB, viewS]针对index做切分。
TileShape 约束:2维场景下TileShape为[tileS, d],4维场景下TileShape为[tileB, tileS, 1, d]。尾轴d不做切分。2维场景下,TileShape针对src做切分,[1,tileS]针对index做切分,tileBS是输入index的第1维s的约数,如src为[12, 64],index为[3, 4],TileShape为[TileS, 64],其中,TileS可以是1、2、4。4维场景下,TileShape针对src做切分,并且,[tileB, tileS]针对index做切分。由于TileShape的切分针对src和index,切块大小之和应小于UB限制。
二维示例: input:[15, 8],index:[5, 2],src:[10, 8], viewShape: [viewB * s, 8], viewB需要是整数,即第0维是s的倍数,tileShape:[tileS, 8], tileS需要是s的约数即1或者2。
调用示例#
TileShape设置示例#
调用该operation接口前,应通过set_vec_tile_shapes设置TileShape。
TileShape维度应和输入src一致。
输入input和输出均在gm上,不涉及tile切分。输入index和输入src需要搬入ub,涉及tile切分。
如输入input为[t, d], dim为-2,输入index为[b, s],输入src为[bs, d], 其中bs=b*s,输出为[t, d], TileShape设置为[bs1, d1], 则bs1用于切分bs轴, d轴不允许切分, d1必须和d相等。
pypto.set_vec_tile_shapes(16, 64)
接口调用示例#
将2维 src 根据2维index更新到2维input上, 注意原地操作的写法, 即等号左边的输出应和输入input相同:
x = pypto.tensor([8, 3], pypto.DT_INT32) y = pypto.tensor([2, 2], pypto.DT_INT64) z = pypto.tensor([4, 3], pypto.DT_INT32) x = pypto.scatter_update(x, -2, y, z)
结果示例如下:
输入数据x:[[0 0 0], [0 0 0], [0 0 0], [0 0 0], [0 0 0], [0 0 0], [0 0 0], [0 0 0]] 输入数据y:[[1 2], [4 5]] 输入数据z:[[1 2 3], [4 5 6], [7 8 9], [10 11 12]] 输出数据x:[[0 0 0], [1 2 3], [4 5 6], [0 0 0], [7 8 9], [10 11 12], [0 0 0], [0 0 0]]
将4维src根据2维索引index更新到4维input上, 注意原地操作的写法, 即等号左边的输出应和输入input相同:
x = pypto.tensor([2, 6, 1, 3], pypto.DT_INT32) y = pypto.tensor([2, 2], pypto.DT_INT64) z = pypto.tensor([2, 2, 1, 3], pypto.DT_INT32) x = pypto.scatter_update(x, -2, y, z)
结果示例如下:
输入数据x:[[ [[0 0 0]], [[0 0 0]], [[0 0 0]], [[0 0 0]], [[0 0 0]], [[0 0 0]], ], [ [[0 0 0]], [[0 0 0]], [[0 0 0]], [[0 0 0]], [[0 0 0]], [[0 0 0]], ]] 输入数据y:[[1 8], [4 10]] 输入数据z:[[ [[1 2 3]], [[4 5 6]], ], [ [[7 8 9]], [[10 11 12]], ]] 输出数据x:[[ [[0 0 0]], [[1 2 3]], [[0 0 0]], [[0 0 0]], [[7 8 9]], [[0 0 0]], ], [ [[0 0 0]], [[0 0 0]], [[4 5 6]], [[0 0 0]], [[10 11 12]], [[0 0 0]], ]]