pypto.expand_clone#
产品支持情况#
产品 |
是否支持 |
|---|---|
Ascend 950PR/Ascend 950DT |
√ |
Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
功能说明#
将输入Tensor在唯一等于1的轴上广播以匹配Shape,返回真实占内存的新Tensor。
函数原型#
expand_clone(
input: Tensor,
shape: List[int],
*,
valid_shape: Optional[List[Union[int, SymbolicScalar]]] = None
) -> Tensor
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
input |
输入 |
源操作数。 |
shape |
输入 |
源操作数,目标Shape。 |
valid_shape |
输入 |
关键字参数。 |
返回值说明#
返回输出Tensor,其数据类型和input相同,形状为shape。
约束说明#
只能一维广播,输入Tensor被广播的轴的Shape大小要为1。
input的viewshape与 input 维度相同,viewshape[dim]=1,input[dim]=1, 其中dim为被拓展轴,其余维度不做限制。举例如下:
[a,1] 拓展到[a,5],其中dim=1,表示在dim 1 上进行拓展。
len(viewshape)=2 并且 viewshape[dim]=1
关于 valid_shape 的说明:
在动态图场景中,假设Tensor input [a,1] 扩展到 [a,5],并设置 ViewShape 为 [a,2],框架会通过 pypto.loop 循环生成 [a,2] 分块,并按偏移量拼接。此时若未传入 valid_shape,代码将默认生成全 [a,2] 的Tensor(如 pypto.expand_clone(input, [a,2]))。
然而,当总尺寸 [a,5] 无法被分块尺寸 [a,2] 整除时,尾块的有效形状(如 [a,1])无法由框架自动推导。例如,最后一列可能仅包含 1 个元素,而非完整的 [a,2] 分块。此时必须通过 valid_shape 明确指定尾块的实际有效Shape,如下:
pypto.expand_clone(input, [a,2], valid_shape = [a, pypto.min(2, 5 - 2 * b_idx),)
其中b_idx 表示循环索引。
tileshape的维度与result 维度相同,用于切分 result。
tileshape 的大小形状无额外约束,只需保证不超过ub size。
调用示例#
TileShape设置示例#
调用该operation接口前,应通过set_vec_tile_shapes设置TileShape。
TileShape维度应和输出一致。
如输入intput shape为[m, 1],输出为[m, n], TileShape设置为[m1, n1], 则m1, n1分别用于切分m, n轴。
pypto.set_vec_tile_shapes(4, 16)
接口调用示例#
# static graph
a = pypto.tensor([1,8], pypto.DT_INT32)
out1 = pypto.expand_clone(a, [4,8])
# dynamic graph
out2 = pypto.expand_clone(a, [4,8], valid_shape = [pypto.symbolic_scalar(4), pypto.symbolic_scalar(8)])
结果示例如下:
输入数据a: [[1, 2, 3, 4, 5, 6, 7, 8]]
输出数据out1: [[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8]]
输出数据out2: [[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8]]