提交 03c6f9bb authored 作者: khaotik's avatar khaotik 提交者: khaotik

add support for undef as override parameter

上级 ece4c2e4
......@@ -9,6 +9,7 @@ from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared, optdb
from theano.gof import ops_with_inner_function
from theano.gof.graph import io_connection_pattern
from theano.gof.utils import undef
class OpFromGraph(gof.Op):
......@@ -28,13 +29,14 @@ class OpFromGraph(gof.Op):
inline: bool, optional
if True, will cause the Op's original graph being used during
compilation, otherwise will use a pre-compiled function inside.
grad_overrides: None | function | list of (None|function), optional
grad_overrides: None | undef | function | list of (None|undef|function), optional
Used to override default gradient routine.
Overriding function(s) must take two list of variable(s) as inputs,
the original inputs and ups gradients
For different `grad_overrides`:
- `None` : will use default gradient routine.
- theano.utils.undef : No gradient will be used (zero)
- function : must return list of Variable.
- list : each function must return a single Variable. The order
of the list must corresponds to inputs
......@@ -196,9 +198,19 @@ class OpFromGraph(gof.Op):
wrt=wrt_l,
disconnected_inputs='ignore') if wrt_l else [])
# combine overriding gradients
all_grads_l = []
for inp, gov in izip(self.local_inputs, goverrides_l):
if gov is None:
all_grads_l.append(next(gdefaults))
elif gov is undef:
all_grads_l.append(
inp.zeros_like().astype(theano.config.floatX))
else:
all_grads_l.append(gov(self.local_inputs, output_grads))
elif self._grad_op is undef:
all_grads_l = [
gov(self.local_inputs, output_grads) if gov
else next(gdefaults) for gov in goverrides_l]
inp.zeros_like().astype(theano.config.floatX)
for inp in self.local_inputs]
else:
all_grads_l = self._grad_op(self.local_inputs, output_grads)
self._grad_op = type(self)(
......@@ -227,6 +239,8 @@ class OpFromGraph(gof.Op):
def grad(self, inputs, output_grads):
if not self._grad_op_is_cached:
self._recompute_grad_op()
if self._grad_op is undef:
return [None for _ in self.input_types]
return self._grad_op(*(list(inputs) + list(output_grads)))
def make_node(self, *inputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论