Fixed backporting bug (conditional expression not parse correctly in my head)

上级 d75b1c78
...@@ -733,15 +733,16 @@ class Clip(ScalarOp): ...@@ -733,15 +733,16 @@ class Clip(ScalarOp):
else: else:
return x return x
#backport
#return min if x < min else max if x > max else x #return min if x < min else max if x > max else x
def c_code(self, node, name, (x, min, max), (z, ), sub): def c_code(self, node, name, (x, min, max), (z, ), sub):
return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals() return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals()
def grad(self, (x, min, max), (gz, )): def grad(self, (x, min, max), (gz, )):
gx = ((x > min) & (x < max)) * gz gx = ((x > min) & (x < max)) * gz
if x.type in grad_types: if x.type in grad_types:
return gx return gx, None, None
else: else:
return None,None,None return None, None, None
#return gx if x.type in grad_types else None, None, None #return gx if x.type in grad_types else None, None, None
clip = Clip(transfer_type(0), name = 'clip') clip = Clip(transfer_type(0), name = 'clip')
...@@ -753,7 +754,7 @@ class First(BinaryScalarOp): ...@@ -753,7 +754,7 @@ class First(BinaryScalarOp):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
if x.type in grad_types: if x.type in grad_types:
return gz return gz, None
else: else:
return None,None return None,None
#backport #backport
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论