pypto.argsort#
产品支持情况#
产品 |
是否支持 |
|---|---|
Ascend 950PR/Ascend 950DT |
√ |
Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
功能说明#
获取输入在指定轴按照升序或者降序进行排序后的索引。
函数原型#
argsort(input: Tensor, dim: Optional[int]=None, descending: bool=True) -> Tensor
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
input |
输入 |
源操作数。 |
dim |
输入 |
指定排序的维度。 |
descending |
输入 |
如果为True,按降序返回索引。如果为False,按升序返回索引。 |
返回值说明#
返回一个Tensor,为输入按照descending在dim轴进行排序后的索引。
约束说明#
当前不支持ViewShape在dim轴切分,即要求ViewShape[dim] = InputShape[dim]。
当前仅支持TileShape在dim轴上为32的倍数,即要求TileShape[dim] % 32 = 0。
当shape较大场景
(tileShape Size/tileShape[dim] * ((viewShape[dim] + 31) / 32 * 32) >= 6KB)时, 排序轴的切分数量小于128。对于四维输入,暂不支持在第0轴上进行排序。
排序时遇到相同值时,采用稳定排序返回对应的索引。
调用示例#
TileShape设置示例#
说明:调用该operation接口前,应通过set_vec_tile_shapes设置TileShape。
TileShape维度应和输入input一致。
如输入intput shape为[m, n, p],dim为2,descending为True,输出为[m, n, p],TileShape设置为[m1, n1, p1],则m1, n1, p1分别用于切分m, n, p轴。
pypto.set_vec_tile_shapes(4, 16, 32)
接口调用示例#
x = pypto.tensor([2, 3], pypto.DT_FP32)
y = pypto.argsort(x, -1, True)
结果示例如下:
输入数据x: [[1.0 2.0 3.0],
[1.0 2.0 3.0]]
输出数据y: [[2, 1, 0],
[2, 1, 0]]