pypto.experimental.transposed_batchmatmul#
产品支持情况#
产品 |
是否支持 |
|---|---|
Ascend 950PR/Ascend 950DT |
√ |
Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
功能说明#
该接口为定制接口,约束较多。不保证稳定性。
该算子执行转置批矩阵乘法。具体操作为:
将输入张量
tensor_a从形状 (M, B, K) 转置为 (B, M, K)。执行批矩阵乘法,将转置后的
tensor_a(B, M, K) 与tensor_b(B, K, N) 相乘,得到中间结果 (B, M, N)。将中间结果转置回形状 (M, B, N) 作为最终输出。
函数原型#
transposed_batchmatmul(tensor_a: Tensor, tensor_b: Tensor, out_dtype: dtype) -> Tensor
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
tensor_a |
输入 |
左侧输入张量。 |
tensor_b |
输入 |
右侧输入张量。 |
out_dtype |
输入 |
输出张量的数据类型。 |
返回值说明#
返回输出 Tensor,Tensor 的数据类型由 out_dtype 指定,形状为 (M, B, N)。
调用示例#
import pypto
# 创建输入张量
a = pypto.tensor((16, 2, 32), pypto.DT_FP16, "tensor_a")
b = pypto.tensor((2, 32, 64), pypto.DT_FP16, "tensor_b")
# 调用算子
c = pypto.experimental.transposed_batchmatmul(a, b, pypto.DT_FP16)
# 输出张量 c 的形状为 (16, 2, 64)