Fixed backporting bug (conditional expression not parse correctly in my head)

上级 d75b1c78
......@@ -733,15 +733,16 @@ class Clip(ScalarOp):
else:
return x
#backport
#return min if x < min else max if x > max else x
def c_code(self, node, name, (x, min, max), (z, ), sub):
return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals()
def grad(self, (x, min, max), (gz, )):
gx = ((x > min) & (x < max)) * gz
if x.type in grad_types:
return gx
return gx, None, None
else:
return None,None,None
return None, None, None
#return gx if x.type in grad_types else None, None, None
clip = Clip(transfer_type(0), name = 'clip')
......@@ -753,7 +754,7 @@ class First(BinaryScalarOp):
return "%(z)s = %(x)s;" % locals()
def grad(self, (x, y), (gz, )):
if x.type in grad_types:
return gz
return gz, None
else:
return None,None
#backport
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论