提交 8e5725ea authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Move condition earlier to avoid useless computations

上级 002872ad
...@@ -427,6 +427,10 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -427,6 +427,10 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
if skip: if skip:
raise SkipTest(skip) raise SkipTest(skip)
if not hasattr(self.op, 'grad'):
# This is not actually an Op
return
for testname, inputs in self.good.items(): for testname, inputs in self.good.items():
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [TensorType( inputrs = [TensorType(
...@@ -456,13 +460,12 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None, ...@@ -456,13 +460,12 @@ def makeTester(name, op, expected, checks=None, good=None, bad_build=None,
var = TensorType(dtype=dtype, broadcastable=bcast)() var = TensorType(dtype=dtype, broadcastable=bcast)()
out_grad_vars.append(var) out_grad_vars.append(var)
if hasattr(self.op, 'grad'): try:
try: in_grad_vars = self.op.grad(inputrs, out_grad_vars)
in_grad_vars = self.op.grad(inputrs, out_grad_vars) except (gof.utils.MethodNotDefined, NotImplementedError):
except (gof.utils.MethodNotDefined, NotImplementedError): pass
pass else:
else: assert None not in in_grad_vars
assert None not in in_grad_vars
Checker.__name__ = name Checker.__name__ = name
return Checker return Checker
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论