提交 7f84b1ad authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Change minimum/maximum grad to raise NotImplementedError for complex

instead of AssertionError, so it gets correctly detected by new test.
上级 8e12c44a
......@@ -1244,8 +1244,10 @@ class Maximum(BinaryScalarOp):
'((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals())
def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types
# max is not defined for complex_types
if gz.type in complex_types:
# max is currently defined for complex_types,
# but the gradient for complex is not.
raise NotImplementedError()
output = self(x, y)
......@@ -1275,8 +1277,10 @@ class Minimum(BinaryScalarOp):
'((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals())
def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types
# max is not defined for complex_types
if gz.type in complex_types:
# min is currently defined for complex_types,
# but the gradient for complex is not.
raise NotImplementedError()
output = minimum(x, y)
if output.type in discrete_types:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论