pypto.transdata#
产品支持情况#
产品 |
是否支持 |
|---|---|
Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
功能说明#
进行Tensor数据格式转换。 支持以下转换:
NCHW2NC1HWC0
NCHW2Fractal_Z
NC1HWC02NCHW
NCDHW2NDC1HWC0
NCDHW2Fractal_Z_3D
NDC1HWC02NCDHW
函数原型#
transdata(input: Tensor, transDataType: TileOpFormat, group: int = 1) -> Tensor:
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
input |
输入 |
源操作数。 支持的类型为:Tensor。 Tensor支持的数据类型为:DT_FP16,DT_BF16,DT_INT8,DT_INT16,DT_INT32,DT_FP32。 不支持空Tensor;Shape仅支持4-5维, 具体与transDataType相关;Shape Size不大于2147483647(即INT32_MAX)。 |
transDataType |
输入 |
转换类型, 支持的类型为:TileOpFormat。 (NDC1HWC02NCDHW 和 NCHW2NC1HWC0 转换类型相同,均为 TILEOP_ND) |
group |
输入 |
分组数, 支持的类型为:int,默认为1 |
返回值说明#
输出Tensor Shape与transDataType相关。
约束说明#
transDataType:指定转换数据类型,必须为TileOpFormat类型;
TileShape的维度与input相同,尾轴32B对齐,所有输入和输出的TileShape大小总和不能超过UB内存的大小;
暂时不支持pad场景;
N0 = 16, C0为当前数据类型32B对齐最小元素个数,例:pypto.DT_INT32的情况下C0为8;
其余约束如下。
转换场景 |
tileshape约束 |
|---|---|
NCHW2NC1HWC0 |
C轴是C0整数倍,group是C轴的因子 |
NCHW2Fractal_Z |
N轴是N0整数倍, C轴是C0整数倍,group是N轴的因子 |
NC1HWC02NCHW |
W轴32B对齐, C0轴不切分, group为NCHW2NC1HWC0的group |
NCDHW2NDC1HWC0 |
C轴是C0整数倍,group是C轴的因子,暂时不支持int8类型 |
NCDHW2Fractal_Z_3D |
N轴是N0整数倍, C轴是C0整数倍,group是N轴的因子 |
NDC1HWC02NCDHW |
W轴32B对齐, C0轴不切分, group为NCDHW2NDC1HWC0的group,暂时仅支持N=1情况,不支持int8类型 |
调用示例#
TileShape设置示例#
调用该operation接口前,应通过set_vec_tile_shapes设置TileShape。
TileShape维度应和输入一致。
例如:输入input shape为[N, C, H, W],TileShape设置为[n,c, h, w],则n,c,h,w分别用于切分N, C, H, W轴。
pypto.set_vec_tile_shapes(4, 16, 16, 16)
接口调用示例#
x = pypto.tensor([1, 16, 1, 8], pypto.DT_INT32) # shape [1, 16, 1, 8]
y = pypto.transdata(x, transDataType, group=1)
结果示例如下:
输入数据 x: [[[[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]],
[[ 16, 17, 18, 19, 20, 21, 22, 23],
[ 24, 25, 26, 27, 28, 29, 30, 31]],
[[ 32, 33, 34, 35, 36, 37, 38, 39],
[ 40, 41, 42, 43, 44, 45, 46, 47]],
[[ 48, 49, 50, 51, 52, 53, 54, 55],
[ 56, 57, 58, 59, 60, 61, 62, 63]],
[[ 64, 65, 66, 67, 68, 69, 70, 71],
[ 72, 73, 74, 75, 76, 77, 78, 79]],
[[ 80, 81, 82, 83, 84, 85, 86, 87],
[ 88, 89, 90, 91, 92, 93, 94, 95]],
[[ 96, 97, 98, 99, 100, 101, 102, 103],
[104, 105, 106, 107, 108, 109, 110, 111]],
[[112, 113, 114, 115, 116, 117, 118, 119],
[120, 121, 122, 123, 124, 125, 126, 127]]]]
输出数据 y: [[[[[ 0, 64],
[ 1, 65],
[ 2, 66],
[ 3, 67],
[ 4, 68],
[ 5, 69],
[ 6, 70],
[ 7, 71]],
[[ 8, 72],
[ 9, 73],
[ 10, 74],
[ 11, 75],
[ 12, 76],
[ 13, 77],
[ 14, 78],
[ 15, 79]],
[[ 16, 80],
[ 17, 81],
[ 18, 82],
[ 19, 83],
[ 20, 84],
[ 21, 85],
[ 22, 86],
[ 23, 87]],
[[ 24, 88],
[ 25, 89],
[ 26, 90],
[ 27, 91],
[ 28, 92],
[ 29, 93],
[ 30, 94],
[ 31, 95]],
[[ 32, 96],
[ 33, 97],
[ 34, 98],
[ 35, 99],
[ 36, 100],
[ 37, 101],
[ 38, 102],
[ 39, 103]],
[[ 40, 104],
[ 41, 105],
[ 42, 106],
[ 43, 107],
[ 44, 108],
[ 45, 109],
[ 46, 110],
[ 47, 111]],
[[ 48, 112],
[ 49, 113],
[ 50, 114],
[ 51, 115],
[ 52, 116],
[ 53, 117],
[ 54, 118],
[ 55, 119]],
[[ 56, 120],
[ 57, 121],
[ 58, 122],
[ 59, 123],
[ 60, 124],
[ 61, 125],
[ 62, 126],
[ 63, 127]]]]]