pypto.cond#
产品支持情况#
产品 |
是否支持 |
|---|---|
Atlas A3 训练系列产品/Atlas A3 推理系列产品 |
√ |
Atlas A2 训练系列产品/Atlas A2 推理系列产品 |
√ |
功能说明#
定义一个if 条件操作,实现python中的if功能。
函数原型#
cond(scalar: SymInt)
参数说明#
参数名 |
输入/输出 |
说明 |
|---|---|---|
scalar |
输入 |
条件表达式,可以是整数或SymbolicScalar(符号标量),用于判断条件是否为真 |
返回值说明#
pypto_impl.RecordIfBranch: 返回一个条件分支对象,用于 Python 的 if 语句
约束说明#
必须与 Python 的 if、elif、else 语句配合使用
条件表达式会被记录到计算图中
支持嵌套的条件语句
当函数未使用 @pypto.frontend.jit 或 @pypto.frontend.function 装饰器修饰时,条件表达式需要用 pypto.cond 包装
调用示例#
# 未使用装饰器,需要用 pypto.cond 包装条件表达式
def kernel():
...
for s2_idx in pypto.loop(0, 10, 1, power_of_2(max_unroll_times), name="LOOP_L0_bIdx_mla_prolog", idx_name="b_idx"):
if pypto.cond(pypto.is_loop_end(s2_idx, bn_per_batch)):
...
# 使用装饰器,无需 pypto.cond 包装
@pypto.frontend.jit
def kernel():
...
for s2_idx in pypto.loop(0, 10, 1, power_of_2(max_unroll_times), name="LOOP_L0_bIdx_mla_prolog", idx_name="b_idx"):
if pypto.is_loop_end(s2_idx, bn_per_batch):
...