pypto.full#

产品支持情况#

产品

是否支持

Ascend 950PR/Ascend 950DT

Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas A2 训练系列产品/Atlas A2 推理系列产品

功能说明#

创建一个大小为 size、填充值为 fill_value 的Tensor。其数据类型为dtype。

函数原型#

full(size: List[int], fill_value: Union[int, float, Element], dtype: DataType, *, valid_shape: Optional[Union[List[int], List[SymbolicScalar]]] = None ) -> Tensor

参数说明#

参数名

输入/输出

说明

size

输入

源操作数,用于定义输出Tensor的Shape。
支持的数据类型为:List[int]。

fill_value

输入

源操作数,用于填充输出Tensor的值。
支持的数据类型为:int, float, Element。
当为 int 或者 float 类型时会自动转换为 Element 类型,其中 int 对应 DT_INT_32,float 对应 DT_FP32。当需要使用其他数据类型时,可以通过 Element 构建。
Element 支持的数据类型为:DT_FP32,DT_INT32,DT_INT16,DT_FP16,DT_BF16。
输入需要和 dtype 类型相同,不支持隐式转化。

dtype

输入

源操作数,用于定义输出Tensor的类型。
支持的数据类型为:DT_FP32,DT_INT32,DT_INT16,DT_FP16,DT_BF16。
输入需要和 fill_value 类型相同,不支持隐式转化。

valid_shape

输入

源操作数,用于定义输出Tensor的动态Shape,关键字参数,用于动态图,静态图可以省略。
支持的类型为 List[SymbolicScalar], List[int]。

返回值说明#

返回输出Tensor,Tensor的数据类型和dtype相同,Shape为size大小,全部的值为fill_value。

约束说明#

  1. valid_shape 用于动态图场景。

    在动态图场景中,若需生成 [5,5] 的 Tensor并设置 ViewShape 为 [2,2],框架会通过 pypto.loop 循环生成 [2,2] 分块,并按偏移量拼接。此时若未传入 valid_shape,代码将默认生成全 [2,2] 的Tensor(如 pypto.full([2,2], 1, pypto.DT_INT32))。

    然而,当总尺寸 [5,5] 无法被分块尺寸 [2,2] 整除时,尾块的有效形状(如 [1,1])无法由框架自动推导。例如,最后一行/列可能仅包含 1 个元素,而非完整的 [2,2] 分块。此时必须通过 valid_shape 明确指定尾块的实际有效形状,如下:

    pypto.full([2, 2], 1, pypto.DT_INT32, valid_shape=[pypto.min(2, 5 - 2 * b_idx), pypto.min(2, 5 - 2 * s_idx)]), 其中b_idx 和 s_idx 表示循环索引。

  2. tileshape的维度与result 维度相同,用于切分 result。

调用示例#

TileShape设置示例#

调用该operation接口前,应通过set_vec_tile_shapes设置TileShape。

TileShape维度应和输出一致。

如输入size为[m, n],输出为[m, n],TileShape设置为[m1, n1],则m1, n1分别用于切分m, n轴。

pypto.set_vec_tile_shapes(4, 16)

接口调用示例#

# Valid shapes use keyword argument
x1 = 1.0 # must be 1.0; implicit conversion is not supported
y1 = pypto.full([2,2], x1, pypto.DT_FP32, valid_shape = [pypto.symbolic_scalar(2), pypto.symbolic_scalar(2)])

x2 = pypto.Element(pypto.DT_INT32,1)
y2 = pypto.full([2,2], x2, pypto.DT_INT32, valid_shape = [pypto.symbolic_scalar(2), pypto.symbolic_scalar(2)])

# In static graphs, validshape can be ignored
x3 = pypto.Element(pypto.DT_INT32,1)
y3 = pypto.full([2,2], x3, pypto.DT_INT32)

结果示例如下:

y1输出数据: [[1.0,1.0], [1.0,1.0]]
y2输出数据: [[1,1],[1,1]]
y3输出数据: [[1,1], [1,1]