pypto.concat#
产品支持情况#
产品 |
是否支持 |
|---|---|
Ascend 950PR/Ascend 950DT |
√ |
Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
功能说明#
将输入的多个Tensor沿指定维度(dim)拼接,返回一个拼接后的Tensor。
函数原型#
concat(tensors: List[Tensor], dim: int = 0) -> Tensor
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
tensors |
输入 |
源操作数。 |
dim |
输入 |
源操作数。 |
返回值说明#
返回输出Tensor,Tensor的数据类型和 tensors 的任一tensor数据类型相同,Shape与 tensors任一tensor相同(除 dim 对应维度),dim 对应维度为 tensors 各个 tensor 对应维度之和。
约束说明#
1.源操作数 tensors 的大小需要大于等于 2,即 len(tensors )>=2;小于等于128。(支持输入一个tensor情况,精度暂时不保证);
2.输入 tensor 数据类型相同、维度数量相同,并且除待拼接维度(dim)之外的每个维度值相同;
3.dim: -input.dim <= dim < input.dim(input对应tensors 的任一tensor);
4.设置viewshape时,dim对应维度不切块(即viewshape对应值>=tensors 任一tensor的对应值)。
调用示例#
TileShape设置示例#
调用该operation接口前,应通过set_vec_tile_shapes设置TileShape。
TileShape维度应和输出一致。
如输入tensors维度为[m, c1, p],[m, c2, p],输出为[m, c1+c2, p],TileShape设置为[m1, n1, p1],则m1, p1分别用于切分m, p轴,n1用于切分c1和c2轴。
pypto.set_vec_tile_shapes(4, 16, 32)
接口调用示例#
a = pypto.tensor([2, 2], pypto.DT_FP32) # 2x2 tensor with all 1s
b = pypto.tensor([2, 2], pypto.DT_FP32) # 2x2 tensor with all 0s
out = pypto.concat([a, b], dim = 0)
结果示例如下:
输入数据a: [[1.0 1.0],
[1.0 1.0]]
输入数据b: [[0.0 0.0],
[0.0 0.0]]
输出数据out: [[1.0 1.0],
[1.0 1.0],
[0.0 0.0],
[0.0 0.0]]