pypto.matmul#

产品支持情况#

产品

是否支持

Ascend 950PR/Ascend 950DT

Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas A2 训练系列产品/Atlas A2 推理系列产品

功能说明#

实现input 、mat2矩阵的矩阵乘运算,计算公式为:out = input @ mat2

  • input 、mat2为源操作数,input 为左矩阵;mat2为右矩阵

  • out 为目的操作数,存放矩阵乘结果的矩阵

注意事项#

  • 左右矩阵数据类型必须一致:matmul 的左右矩阵数据类型必须相同(如 BF16+BF16、FP16+FP16),不支持混合输入(如 BF16+FP32)

  • 推荐使用低精度输入:BF16/FP16 输入直接 matmul 输出 FP32,比先 cast 到 FP32 再 matmul 性能更好,且精度相当

  • 避免不必要的 cast:将 BF16 升级到 FP32 再进行 matmul 计算不会有精度提升,反而会产生额外的数据搬移开销

  • 利用随路 transpose:matmul 支持 a_transb_trans 参数,可以在矩阵乘时随路完成转置,避免额外调用 transpose 操作

  • 必须先设置 TileShape:调用 matmul 接口前需要通过 set_cube_tile_shapes 设置 M、N、K 轴上的切分大小

函数原型#

matmul(input, mat2, out_dtype, *, a_trans = False, b_trans = False, c_matrix_nz = False, extend_params=None) -> Tensor

参数说明#

表1:API参数说明

参数名

输入/输出

说明

input

输入

表示输入左矩阵。不支持输入空Tensor。
支持的数据类型为:DT_INT8, DT_FP16, DT_BF16,DT_FP32,DT_HF8,DT_FP8E5M2,DT_FP8E4M3。当input的数据类型为DT_FP8E5M2时,mat2的数据类型可以为DT_FP8E5M2或DT_FP8E4M3,当input的数据类型为DT_FP8E4M3时,mat2的数据类型可以为DT_FP8E5M2或DT_FP8E4M3,其余数据类型,需保证左右矩阵一致。
支持的矩阵维度:2维、3维、4维,且左右矩阵维度需保持一致。
输入矩阵支持的Format为:TILEOP_ND, TILEOP_NZ(DT_FP32,DT_FP8E5M2输入不支持TILEOP_NZ格式)。
当Format为TILEOP_ND(ND格式)时,外轴范围为[1, 2^31 - 1],内轴范围为[1, 65535]。
当Format为TILEOP_NZ(NZ格式)时,其Shape维度需满足内轴32字节对齐(当输出矩阵数据类型为DT_INT32时,内轴为16元素对齐),外轴16元素对齐。
内轴外轴:当输入矩阵input非转置时,对应数据排布为[M, K],此时外轴为M,内轴为K;当输入矩阵input转置时,对应数据排布为[K, M],此时外轴为K,内轴为M;
在使用pypto.view接口的场景,应保证传入View的Shape维度也满足内轴32字节对齐(当输出矩阵数据类型为DT_INT32时,内轴为16元素对齐),外轴16元素对齐
当矩阵维度为3维或者4维时,不支持pypto.view场景。

mat2

输入

表示输入右矩阵。不支持输入空Tensor。
支持的数据类型为:DT_INT8, DT_FP16, DT_BF16,DT_FP32,DT_HF8,,DT_FP8E5M2,DT_FP8E4M。
支持的矩阵维度:2维、3维、4维,且左右矩阵维度需保持一致。
输入矩阵支持的Format为:TILEOP_ND, TILEOP_NZ(DT_FP32,DT_FP8E5M2输入不支持TILEOP_NZ格式)。
当Format为TILEOP_ND(ND格式)时,外轴范围为[1, 2^31 - 1],内轴范围为[1, 65535]。
当Format为TILEOP_NZ(NZ格式)时,其Shape维度需满足内轴32字节对齐(当输出矩阵数据类型为DT_INT32时,内轴为16元素对齐),外轴16元素对齐。
内轴外轴:当输入矩阵mat2非转置时,对应数据排布为[K, N],此时外轴为K,内轴为N;当输入矩阵mat2转置时,对应数据排布为[N, K],此时外轴为N,内轴为K;
在使用pypto.view接口的场景,应保证传入View的Shape维度也满足内轴32字节对齐(其中当输出矩阵数据类型为DT_INT32时,内轴为16元素对齐),外轴16元素对齐。
当矩阵维度为3维或者4维时,不支持pypto.view场景。

out_dtype

输出

表示输出矩阵数据类型,支持DT_FP32,DT_FP16,DT_BF16, DT_INT32。
- 输入矩阵数据类型为DT_FP16时,out_dtype可选DT_FP32, DT_FP16。
- 输入矩阵数据类型为DT_BF16时,out_dtype可选DT_FP32, DT_BF16。
- 输入矩阵数据类型为DT_INT8时,out_dtype可选DT_INT32。
- 输入矩阵数据类型为DT_FP32时,out_dtype可选DT_FP32。
- 输入矩阵数据类型为DT_FP8E5M2,DT_FP8E4M3时,out_dtype可选DT_FP16,DT_BF16,DT_FP32。

a_trans

输入

参数a_trans表示输入左矩阵是否转置,默认为False。

b_trans

输入

参数b_trans表示输入右矩阵是否转置,默认为False。

c_matrix_nz

输入

参数c_matrix_nz表示输出矩阵的Format是否采用NZ格式,默认为False,当前仅支持设置False,即输出矩阵仅支持ND格式。

extend_params

输入

支持bias、fixpipe反量化及TF32舍入模式TransMode功能,详见下表。
- 数据类型为字典格式。
- 此参数与其内部参数均为可选参数。

表2:extend_params参数说明

参数名

说明

scale

表示pertensor量化场景(使用同一个缩放因子将高精度数映射到低精度数)输出矩阵反量化的参数。
输入为float类型,取1位符号位 + 8位指数位 + 10位尾数位参与运算。

scale_tensor

表示perchannel量化场景(对每一个输出通道独立计算一套量化参数)输出矩阵反量化的矩阵。
scale_tensor输入固定为uint64_t 的Tensor。计算时会转换uint64_t为float类型的低32位bit后,取1位符号位 + 8位指数位 + 10位尾数位参与运算。
scale_tensor的第一维度必须置1,且N维度需要与mat2矩阵的N维度相等。
scale_tensor只支持ND格式。
仅支持矩阵维度为2维场景。

bias_tensor

表示偏置矩阵。
输入为Tensor类型。
输入左右矩阵数据类型为DT_FP16时,Bias矩阵数据类型可选DT_FP16和DT_FP32。
输入左右矩阵数据类型为DT_FP32时,Bias矩阵数据类型只能为DT_FP32。
输入左右矩阵数据类型为DT_INT8时,Bias矩阵数据类型只能为DT_INT32。
输入左右矩阵数据类型为DT_FP8E5M2,DT_FP8E4M3时,Bias矩阵数据类型只能为DT_FP32。
bias_tensor只支持ND格式。
bias_tensor的第一维度应置1,且N维度需要与mat2矩阵的N维度相等。
Bias不支持多核切K功能。
仅支持矩阵维度为2维场景。

relu_type

表示输出矩阵是否进行ReLu操作。
输入为ReLuType类型。
支持RELU和NO_RELU两种模式。
仅支持矩阵维度为2维场景。

trans_mode

表示是否使能TF32计算及TF32舍入模式。
输入为TransMode类型。
CAST_NONE:不使能float数据类型转换为TF32数据类型。
CAST_RINT:使能float数据类型转换为TF32数据类型,舍入规则:舍入到最近整数,中间值时舍入到偶数。
CAST_ROUND:使能float数据类型转换为TF32数据类型,舍入规则:舍入到最近整数,中间值时远离零舍入。
仅支持输入左右矩阵和输出矩阵数据类型均为DT_FP32时设置。
仅支持矩阵维度为2维场景。

返回值说明#

返回值为out 矩阵(Tensor)。

约束说明#

  • Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持DT_HF8,DT_FP8E5M2,DT_FP8E4M3,不支持extend_params中的trans_mode参数。

  • Atlas A3 训练系列产品/Atlas A3 推理系列产品:不支持DT_HF8,DT_FP8E5M2,DT_FP8E4M3,不支持extend_params中的trans_mode参数。

  • 调用matmul接口前需要通过pypto.set_cube_tile_shapes设置M、N、K轴上的切分大小

  • 当矩阵维度为3维或者4维时,需要调用pypto.set_vec_tile_shapes接口设置vector的TileShape切分,如未设置,接口内部会设置2维的vec_tile_shape,其值为128,128。

  • 调用matmul接口的输入为调用pypto.reshape后的NZ格式时,需要调用pypto.set_matrix_size接口设置pypto.reshape前的输入到matmul的原始Shape的m,k,n值。

  • 调用matmul接口的输入矩阵维度为3维/4维并且数据格式为NZ格式时,需要调用pypto.set_matrix_size接口设置输入到matmul的原始Shape的m,k,n值。

调用示例#

a1 = pypto.tensor([16, 32], pypto.DT_BF16, "tensor_a")
b1 = pypto.tensor([32, 64], pypto.DT_BF16, "tensor_b")
out1 = pypto.matmul(a1, b1, pypto.DT_BF16)

a2 = pypto.tensor((2, 16, 32), pypto.DT_FP16, "tensor_a")
b2 = pypto.tensor((2, 32, 16), pypto.DT_FP16, "tensor_b")
out2 = pypto.matmul(a2, b2, pypto.DT_FP16)

a3 = pypto.tensor((1, 32, 64), pypto.DT_FP32, "tensor_a")
b3 = pypto.tensor((3, 64, 16), pypto.DT_FP32, "tensor_b")
out3 = pypto.matmul(a3, b3, pypto.DT_FP32)

a = pypto.tensor((16, 32), pypto.DT_FP16, "tensor_a")
b = pypto.tensor((32, 64), pypto.DT_FP16, "tensor_b")
bias = pypto.tensor((1, 64), pypto.DT_FP16, "tensor_bias")
extend_params = {'bias_tensor': bias}
pypto.matmul(a, b, pypto.DT_FP32, a_trans=False, b_trans=False, c_matrix_nz=False, extend_params=extend_params)

a = pypto.tensor((16, 32), pypto.DT_INT8, "tensor_a")
b = pypto.tensor((32, 64), pypto.DT_INT8, "tensor_b")
extend_params = {'scale': 0.2}
pypto.matmul(a, b, pypto.DT_BF16, a_trans=False, b_trans=False, c_matrix_nz=False, extend_params=extend_params)

 a = pypto.tensor((16, 32), pypto.DT_INT8, "tensor_a")
 b = pypto.tensor((32, 64), pypto.DT_INT8, "tensor_b")
 extend_params = {'scale': 0.2, 'relu_type': pypto.ReLuType.RELU}
 pypto.matmul(a, b, pypto.DT_BF16, a_trans=False, b_trans=False, c_matrix_nz=False, extend_params=extend_params)

 a = pypto.tensor((16, 32), pypto.DT_INT8, "tensor_a")
 b = pypto.tensor((32, 64), pypto.DT_INT8, "tensor_b")
 scale_tensor = pypto.tensor((1, 64), pypto.DT_UINT64, "tensor_scale")
 extend_params = {'scale_tensor': scale_tensor, 'relu_type': pypto.ReLuType.RELU}
 pypto.matmul(a, b, pypto.DT_BF16, a_trans=False, b_trans=False, c_matrix_nz=False, extend_params=extend_params)