pypto.scaled_mm#
产品支持情况#
产品 |
是否支持 |
|---|---|
Ascend 950PR/Ascend 950DT |
√ |
功能说明#
实现mat_a 、mat_b矩阵的mx量化矩阵乘运算,计算公式为:out = (mat_a * scale_a) @ (mat_b * scale_b)
mat_a 、mat_b 、scale_a 、scale_b为源操作数,mat_a 为左矩阵;mat_b为右矩阵;scale_a为左矩阵量化参数;scale_b为右矩阵量化参数
out 为目的操作数,存放矩阵乘结果的矩阵
函数原型#
scaled_mm(mat_a, mat_b, out_dtype, scale_a, scale_b, *, a_trans = False, b_trans = False, scale_a_trans = False, scale_b_trans = False, c_matrix_nz = False, extend_params=None) -> Tensor
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
mat_a |
输入 |
表示输入左矩阵。不支持输入空Tensor。 |
mat_b |
输入 |
表示输入右矩阵。不支持输入空Tensor。 |
out_dtype |
输出 |
表示输出矩阵数据类型,支持DT_FP32,DT_FP16,DT_BF16。 |
scale_a |
输入 |
表示输入左矩阵量化参数。不支持输入空Tensor。 |
scale_b |
输入 |
表示输入右矩阵量化参数。不支持输入空Tensor。 |
a_trans |
输入 |
参数a_trans表示输入左矩阵是否转置,默认为False。 |
b_trans |
输入 |
参数b_trans表示输入右矩阵是否转置,默认为False。 |
scale_a_trans |
输入 |
参数scale_a_trans表示输入左矩阵量化参数是否转置,默认为False。 |
scale_b_trans |
输入 |
参数scale_b_trans表示输入右矩阵量化参数是否转置,默认为False。 |
c_matrix_nz |
输入 |
参数c_matrix_nz表示输出矩阵的Format是否采用NZ格式,默认为False,当前仅支持设置False,即输出矩阵仅支持ND格式。 |
extend_params |
输入 |
支持bias及fixpipe的反量化功能,数据类型为字典格式。默认为None,当前仅支持bias场景。详见下表 |
表2:extend_params参数说明
参数名 |
说明 |
|---|---|
bias_tensor |
表示偏置矩阵。 |
返回值说明#
返回值为out 矩阵(Tensor)。
约束说明#
调用matmul接口前需要通过pypto.set_cube_tile_shapes设置M、N、K轴上的切分大小。
调用matmul接口的输入为调用pypto.reshape后的NZ格式时,需要调用pypto.set_matrix_size接口设置pypto.reshape前的输入到matmul的原始Shape的m,k,n值。
调用示例#
mat_a = pypto.tensor([64, 128], pypto.DT_FP8E5M2, "mat_a")
mat_b = pypto.tensor([128, 32], pypto.DT_FP8E5M2, "mat_b")
scale_a = pypto.tensor([64, 2, 2], pypto.DT_FP8E8M0, "scale_a")
scale_b = pypto.tensor([2, 32, 2], pypto.DT_FP8E8M0, "scale_b")
out1 = pypto.scaled_mm(mat_a, mat_b, pypto.DT_BF16, scale_a, scale_b)
mat_a = pypto.tensor([128, 64], pypto.DT_FP8E5M2, "mat_a")
mat_b = pypto.tensor([32, 128], pypto.DT_FP8E5M2, "mat_b")
scale_a = pypto.tensor([2, 64, 2], pypto.DT_FP8E8M0, "scale_a")
scale_b = pypto.tensor([32, 2, 2], pypto.DT_FP8E8M0, "scale_b")
bias = pypto.tensor((1, 32), pypto.DT_FP16, "tensor_bias")
extend_params = {'bias_tensor': bias}
out1 = pypto.scaled_mm(mat_a, mat_b, pypto.DT_BF16, scale_a, scale_b, scale_a_trans=True, scale_b_trans=True, extend_params=extend_params)