条件与分支#
条件和分支用于在程序中实现条件判断,从而根据不同的条件执行不同的代码逻辑。编程框架支持两类条件与分支功能:
静态条件分支:在编译时配置条件分支,生成执行时固定的指令,可通过多个jit生成不同的kernel。
动态条件分支:在运行时判断条件及分支,执行对应的功能。
静态条件分支#
# 使用入参 add1_flag=False 生成 kernel
@pypto.frontend.jit
def add_kernel_false(
input0: pypto.Tensor([pypto.DYNAMIC, 4, 1, 64], pypto.DT_FP32),
input1: pypto.Tensor([pypto.DYNAMIC, 4, 1, 64], pypto.DT_FP32),
output: pypto.Tensor([pypto.DYNAMIC, 4, 1, 64], pypto.DT_FP32),
val: int
):
add_core(input0, input1, output, val, False)
#使用入参add1_flag=True生成kernel
@pypto.frontend.jit
def add_kernel_true(
input0: pypto.Tensor([pypto.DYNAMIC, 4, 1, 64], pypto.DT_FP32),
input1: pypto.Tensor([pypto.DYNAMIC, 4, 1, 64], pypto.DT_FP32),
output: pypto.Tensor([pypto.DYNAMIC, 4, 1, 64], pypto.DT_FP32),
val: int
):
add_core(input0, input1, output, val, True)
代码示例:
def add_core(input0: pypto.Tensor, input1: pypto.Tensor, output: pypto.Tensor, val: int, add1_flag: bool = False):
# Tiling 配置与循环逻辑
pypto.set_vec_tile_shapes(1, 4, 1, 64)
#calculate the loop parameters
b = input0.shape[0]
tile_b = 1
b_loop = b // tile_b
for idx in pypto.loop(b_loop):
b_offset = idx * tile_b
b_offset_end = (idx + 1) * tile_b
t0_sub = input0[b_offset:b_offset_end, ...]
t1_sub = input1[b_offset:b_offset_end, ...]
t3_sub = t0_sub + t1_sub
if add1_flag:
output[b_offset:b_offset_end, ...] = t3_sub + val
else:
output[b_offset:b_offset_end, ...] = t3_sub
该用例在add_kernel函数中增加了一个可选参数add1_flag,并使用该参数进行不同的处理。如果add1_flag为True,则在输出结果上加1;反之,则直接输出前一个处理的结果。
完整样例请参考:condition.py。
动态条件分支#
运行时判断条件及分支,执行对应的功能。核心接口包括:
pypto.cond(condition):运行时判断条件。pypto.is_loop_begin(idx):判断是否为循环首个迭代。pypto.is_loop_end(idx):判断是否为循环最后一个迭代。
@pypto.frontend.jit
def add_kernel(
input0: pypto.Tensor([pypto.DYNAMIC, 4, 1, 64], pypto.DT_FP32),
input1: pypto.Tensor([pypto.DYNAMIC, 4, 1, 64], pypto.DT_FP32),
output: pypto.Tensor([pypto.DYNAMIC, 4, 1, 64], pypto.DT_FP32),
val: int
):
...
for idx in pypto.loop(b_loop):
t3_sub = t0_sub + t1_sub
if idx < 2: # 动态条件判断
output[b_offset:b_offset_end, ...] = t3_sub + val
else:
output[b_offset:b_offset_end, ...] = t3_sub
# 或者基于循环位置的条件
if pypto.is_loop_begin(idx):
output[b_offset:b_offset_end, ...] = t3_sub + val
elif pypto.is_loop_end(idx):
output[b_offset:b_offset_end, ...] = t3_sub + val + 1
else:
output[b_offset:b_offset_end, ...] = t3_sub
完整样例请参考: