pypto.experimental.transposed_batchmatmul#

产品支持情况#

产品

是否支持

Ascend 950PR/Ascend 950DT

Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas A2 训练系列产品/Atlas A2 推理系列产品

功能说明#

该接口为定制接口,约束较多。不保证稳定性。

该算子执行转置批矩阵乘法。具体操作为:

  1. 将输入张量 tensor_a 从形状 (M, B, K) 转置为 (B, M, K)。

  2. 执行批矩阵乘法,将转置后的 tensor_a (B, M, K) 与 tensor_b (B, K, N) 相乘,得到中间结果 (B, M, N)。

  3. 将中间结果转置回形状 (M, B, N) 作为最终输出。

函数原型#

transposed_batchmatmul(tensor_a: Tensor, tensor_b: Tensor, out_dtype: dtype) -> Tensor

参数说明#

参数名

输入/输出

说明

tensor_a

输入

左侧输入张量。
支持的数据类型为:DT_FP16, DT_BF16。
不支持空Tensor,支持三维。
形状必须为 (M, B, K)。

tensor_b

输入

右侧输入张量。
支持的数据类型为:DT_FP16, DT_BF16。
不支持空Tensor,支持三维。
形状必须为 (B, K, N)。

out_dtype

输入

输出张量的数据类型。
支持的数据类型为:DT_FP16, DT_BF16。

返回值说明#

返回输出 Tensor,Tensor 的数据类型由 out_dtype 指定,形状为 (M, B, N)。

调用示例#

import pypto

# 创建输入张量
a = pypto.tensor((16, 2, 32), pypto.DT_FP16, "tensor_a")
b = pypto.tensor((2, 32, 64), pypto.DT_FP16, "tensor_b")

# 调用算子
c = pypto.experimental.transposed_batchmatmul(a, b, pypto.DT_FP16)

# 输出张量 c 的形状为 (16, 2, 64)