pypto.from_torch#

产品支持情况#

产品

是否支持

Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas A2 训练系列产品/Atlas A2 推理系列产品

功能说明#

将一个torch.Tensor转换为pypto.Tensor。可显式指定转换后的pypto.Tensor的名称。可将转换后的pypto.Tensor的指定维度标记为动态维度,用于表示该维度在后续编译/运行阶段可变。

函数原型#

from_torch(tensor: torch.Tensor, name: str="", *, dynamic_axis: Optional[List[int]] = None,
           tensor_format: Optional[TileOpFormat] = None, dtype: Optional[DataType] = None) -> pypto.Tensor

参数说明#

参数名

输入/输出

说明

tensor

输入

需要转换为pypto.Tensor的torch.Tensor对象。

name

输入

pypto.Tensor的名称。默认为空字符串,表示由from_torch自动为其命名。

dynamic_axis

输入

要标记为动态的维度索引列表。默认为None,表示不标记任何维度。

tensor_format

输入

要指定的pypto.TileOpFormat格式。为None时根据Tensor NPU Fromat 自动推导。

dtype

输入

要指定的pypto.DataType类型。为None时根据torch.Tensor的dtype自动推导。

返回值说明#

返回转换后的pypto.Tensor。

约束说明#

  • 入参tensor类型必须为torch.Tensor或其子类。

  • 入参tensor在指定内存格式的顺序下是连续的(tensor.is_contiguous() == True)。

  • 入参tensor支持如下数据类型(dtype):

    • torch.float16

    • torch.bfloat16

    • torch.float32

    • torch.float64

    • torch.int8

    • torch.uint8

    • torch.int16

    • torch.uint16

    • torch.int32

    • torch.uint32

    • torch.int64

    • torch.uint64

    • torch.bool

调用示例#

x= torch.randn(2, 3)
x_pto = pypto.from_torch(x)
print(x_pto.shape)
y = torch.randn(2, 3)
y_pto = pypto.from_torch(y, "y", dynamic_axis=[0])
print(y_pto.shape)
z = torch.randn(2, 3)
z_pto = pypto.from_torch(z, "z", tensor_format=pypto.TileOpFormat.TILEOP_NZ)
print(z_pto.format)
k = torch.randn(2, 3)
k_pto = pypto.from_torch(k, "k", dtype=pypto.DataType.DT_HF8)
print(k_pto.dtype)

结果示例如下:

[2, 3]
[SymbolicScalar(RUNTIME_GetInputShapeDim(ARG_input_tensor,0)), 3]
TileOpFormat.TILEOP_NZ
DataType.DT_HF8