pypto.set_matrix_size#

产品支持情况#

产品

是否支持

Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas A2 训练系列产品/Atlas A2 推理系列产品

功能说明#

将NZ的输入Tensor在经过reshape后,传入matmul使用计算时,需要将原始的Tensor(未reshape前)的m,k,n值传入,使matmul获取到原始m,k,n值。

函数原型#

set_matrix_size(size: List[int])-> None

参数说明#

参数名

输入/输出

说明

size

输入

输入Tensor的m,k,n值

返回值说明#

void

约束说明#

1、NZ输入的Tensor,在经过reshape后,调用matmul计算时,需设置该参数。

2、调用matmul的输入是3维/4维的NZ格式Tensor,需设置该参数。

调用示例#

a = pypto.tensor((1, 32, 64), pypto.DT_FP32, "tensor_a")
b = pypto.tensor((3, 64, 16), pypto.DT_FP32, "tensor_b")
pypto.set_matrix_size([32, 64, 16]) #对应输入的Tensor的m,k,n值
out = pypto.matmul(a, b, pypto.DT_FP32)