提交 9365cdec authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Ignore unused inputs in some legitimate cases.

上级 2720aeed
...@@ -59,9 +59,12 @@ class OpFromGraph(gof.Op): ...@@ -59,9 +59,12 @@ class OpFromGraph(gof.Op):
if g is None: if g is None:
self.grad_ops.append(lambda *args: None) self.grad_ops.append(lambda *args: None)
else: else:
# It is normal if some inputs are not needed in order
# to compute the gradient, so we ignore them.
self.grad_ops.append(OpFromGraph(inputs + output_grads, self.grad_ops.append(OpFromGraph(inputs + output_grads,
[g], [g],
grad_depth = grad_depth - 1)) grad_depth = grad_depth - 1,
on_unused_input='ignore'))
def __eq__(self, other): def __eq__(self, other):
#TODO: recognize a copy #TODO: recognize a copy
return self is other return self is other
......
...@@ -473,9 +473,9 @@ class Method(Component): ...@@ -473,9 +473,9 @@ class Method(Component):
else: else:
effective_mode = self.mode effective_mode = self.mode
#backport # We ignore unused inputs, since all the inputs are passed
#effective_mode = mode if self.mode is None else self.mode rval = F.orig_function(inputs, outputs, effective_mode,
rval = F.orig_function(inputs, outputs, effective_mode) on_unused_input='ignore')
memo[self] = rval memo[self] = rval
return rval return rval
......
...@@ -765,10 +765,11 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, ...@@ -765,10 +765,11 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None,
def function(inputs, output): def function(inputs, output):
if mode is None: if mode is None:
f = compile.function(inputs, output, accept_inplace=True, f = compile.function(inputs, output, accept_inplace=True,
allow_input_downcast=True) allow_input_downcast=True, on_unused_input='ignore')
else: else:
f = compile.function(inputs, output, accept_inplace=True, f = compile.function(inputs, output, accept_inplace=True,
allow_input_downcast=True, mode=mode) allow_input_downcast=True, mode=mode,
on_unused_input='ignore')
return f return f
tensor_pt = [TensorType( tensor_pt = [TensorType(
......
...@@ -793,7 +793,8 @@ def scan(fn, ...@@ -793,7 +793,8 @@ def scan(fn,
dummy_outs, dummy_outs,
updates=updates, updates=updates,
mode=compile.mode.Mode(linker='py', mode=compile.mode.Mode(linker='py',
optimizer=None)) optimizer=None),
on_unused_input='ignore')
## ##
### Step 5. Re-arange inputs of scan into a more strict order ### Step 5. Re-arange inputs of scan into a more strict order
......
...@@ -480,7 +480,8 @@ class Scan(PureOp): ...@@ -480,7 +480,8 @@ class Scan(PureOp):
wrapped_outputs, wrapped_outputs,
mode=self.mode_instance, mode=self.mode_instance,
name=self.name, name=self.name,
profile=profile) profile=profile,
on_unused_input='ignore')
try: try:
cython_mintaps = numpy.asarray(self.mintaps, dtype='int32') cython_mintaps = numpy.asarray(self.mintaps, dtype='int32')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论