pypto.set_cube_tile_shapes#
产品支持情况#
产品 |
是否支持 |
|---|---|
Ascend 950PR/Ascend 950DT |
√ |
Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
功能说明#
在调用pypto.matmul或pypto.scaled_mm前必须调用本接口设置矩阵运算的切分大小,具体切分配置可参考Matmul高性能编程。
函数原型#
set_cube_tile_shapes(m: List[int], k: List[int], n: List[int], enable_split_k: bool = False) -> None
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
m |
输入 |
m维度在L0和L1上的TileShape(切片形状)的切分大小,分别对应mL0和mL1的切分大小 |
k |
输入 |
k维度在L0和L1上的TileShape(切片形状)的切分大小,分别对应kL0和kL1的切分大小 |
n |
输入 |
n维度在L0和L1上的TileShape(切片形状)的切分大小,分别对应nL0和nL1的切分大小 |
enable_split_k |
输入 |
设置True表示使能matmul的多核切K功能(便捷开关,不保证性能最优);默认为False,表示未使能多核切K。性能调优推荐通过前端手动切K实现,详见Matmul高性能编程 |
返回值说明#
void
约束说明#
对齐约束
通用对齐约束
要求mL0、mL1、kL0、kL1、nL0、nL1均满足32字节对齐(DT_FP32输入场景要求满足16元素对齐)。例如:输入矩阵的数据类型为DT_FP16时,kL0 * sizeof(DT_FP16) % 32 == 0。
基础关系约束
约束项
要求
mL0 与 mL1
mL0 > 0且mL0 ≤ mL1且mL1 % mL0 == 0kL0 与 kL1
kL0 > 0且kL0 ≤ kL1且kL1 % kL0 == 0nL0 与 nL1
nL0 > 0且nL0 ≤ nL1且nL1 % nL0 == 0ND格式特有约束
当A矩阵在format为ND且转置场景时(即数据排布为[K, M]),要求mL0满足32字节对齐。
NZ格式特有约束
A、B矩阵在format为NZ场景时,要求外轴切分大小满足16元素对齐,内轴切分大小满足32字节对齐。例如,在A矩阵非转置场景,外轴为M、内轴为K,要求mL0、mL1满足16元素对齐,kL0、kL1满足32字节对齐。
空间约束
输入dtype为DT_FP16或DT_BF16或DT_FP32:
CeilAlign(mL0, 16) × CeilAlign(kL0, 16) × sizeof(aDtype) ≤ L0A_size CeilAlign(nL0, 16) × CeilAlign(kL0, 16) × sizeof(bDtype) ≤ L0B_size CeilAlign(mL0, 16) × CeilAlign(nL0, 16) × sizeof(cDtype) ≤ L0C_size CeilAlign(mL1, 16) × CeilAlign(kL1, 16) × sizeof(aDtype) + CeilAlign(nL1, 16) × CeilAlign(kL1, 16) × sizeof(bDtype) ≤ L1_size
输入dtype为DT_INT8或DT_FP8E5M2或DT_FP8E4M3或DT_HF8:
CeilAlign(mL0, 32) × CeilAlign(kL0, 32) × sizeof(aDtype) ≤ L0A_size CeilAlign(nL0, 32) × CeilAlign(kL0, 32) × sizeof(bDtype) ≤ L0B_size CeilAlign(mL0, 32) × CeilAlign(nL0, 32) × sizeof(cDtype) ≤ L0C_size CeilAlign(mL1, 32) × CeilAlign(kL1, 32) × sizeof(aDtype) + CeilAlign(nL1, 32) × CeilAlign(kL1, 32) × sizeof(bDtype) ≤ L1_size
Bias空间约束:
bias数据到达BTBuffer全部转为fp32,需满足以下约束:
nL0 × 4 ≤ BTBuffer_size
FixPipe空间约束:
scaleTensor数据为uint64_t,需满足以下约束:
nL0 × 8 ≤ FixBuffer_size
其中:
aDtype、bDtype为输入矩阵数据类型
cDtype为输出矩阵数据类型,当输入为DT_INT8时cDtype为DT_INT32,其余场景cDtype为DT_FP32
CeilAlign(value, align)元素对齐实现为:(value + align - 1) / align * align
多核切K约束
仅支持2维矩阵,3维/4维矩阵不支持多核切K
多核切K场景只支持out_dtype数据类型为DT_FP32或DT_INT32。
Bias、FixPipe反量化场景不支持叠加多核切K功能。
调用示例#
# 基本配置
pypto.set_cube_tile_shapes([128, 128], [128, 128], [128, 128])
# 启用多核切K(便捷开关,不保证性能最优;性能调优见 Matmul高性能编程)
pypto.set_cube_tile_shapes([128, 128], [64, 256], [128, 128], enable_split_k=True)