Commit 03c6f9bb authored by khaotik, committed by khaotik

add support for undef as override parameter

Parent ece4c2e4
@@ -9,6 +9,7 @@ from theano.compile.function_module import orig_function
 from theano.compile import SharedVariable, rebuild_collect_shared, optdb
 from theano.gof import ops_with_inner_function
 from theano.gof.graph import io_connection_pattern
+from theano.gof.utils import undef


 class OpFromGraph(gof.Op):
@@ -28,13 +29,14 @@ class OpFromGraph(gof.Op):
     inline: bool, optional
         if True, the Op's original graph will be used during
         compilation; otherwise a pre-compiled function will be used inside.
-    grad_overrides: None | function | list of (None|function), optional
+    grad_overrides: None | undef | function | list of (None|undef|function), optional
         Used to override the default gradient routine.
         Overriding function(s) must take two lists of variables as inputs:
         the original inputs and the upstream gradients.
         For different `grad_overrides`:
         - `None`: the default gradient routine will be used.
+        - `theano.gof.utils.undef`: no gradient will be used (zero).
         - function: must return a list of Variables.
         - list: each function must return a single Variable. The order
           of the list must correspond to the inputs.
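To make the documented semantics concrete, here is a minimal usage sketch. It assumes `OpFromGraph` is constructed as in this file (`theano.compile.builders`) and that `undef` comes from `theano.gof.utils`, as the import hunk above adds; the toy graph `z = x * y` is illustrative only:

```python
import theano.tensor as T
from theano.compile.builders import OpFromGraph
from theano.gof.utils import undef

x, y = T.scalars('x', 'y')
z = x * y

# Single override for the whole op: undef blocks the gradient,
# so every input receives a zero gradient (per the docstring).
op_blocked = OpFromGraph([x, y], [z], grad_overrides=undef)

# Per-input list: default routine for x, zero gradient for y.
# A function returning a single Variable could replace either entry.
op_mixed = OpFromGraph([x, y], [z], grad_overrides=[None, undef])
```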
@@ -196,9 +198,19 @@ class OpFromGraph(gof.Op):
                     wrt=wrt_l,
                     disconnected_inputs='ignore') if wrt_l else [])
             # combine overriding gradients
-            all_grads_l = [
-                gov(self.local_inputs, output_grads) if gov
-                else next(gdefaults) for gov in goverrides_l]
+            all_grads_l = []
+            for inp, gov in izip(self.local_inputs, goverrides_l):
+                if gov is None:
+                    all_grads_l.append(next(gdefaults))
+                elif gov is undef:
+                    all_grads_l.append(
+                        inp.zeros_like().astype(theano.config.floatX))
+                else:
+                    all_grads_l.append(gov(self.local_inputs, output_grads))
+        elif self._grad_op is undef:
+            all_grads_l = [
+                inp.zeros_like().astype(theano.config.floatX)
+                for inp in self.local_inputs]
         else:
             all_grads_l = self._grad_op(self.local_inputs, output_grads)
         self._grad_op = type(self)(
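The per-input dispatch added in this hunk can be read in isolation; the following plain-Python sketch mirrors its control flow, with a stand-in `UNDEF` sentinel and plain numbers in place of Theano variables (all names here are illustrative):

```python
UNDEF = object()  # stand-in for theano.gof.utils.undef

def combine_grads(inputs, goverrides_l, gdefaults, output_grads):
    # One gradient per input, chosen by that input's override entry.
    all_grads_l = []
    for inp, gov in zip(inputs, goverrides_l):
        if gov is None:
            # No override: consume the next default gradient.
            all_grads_l.append(next(gdefaults))
        elif gov is UNDEF:
            # Blocked: this input gets a zero gradient.
            all_grads_l.append(0.0)
        else:
            # Callable override: receives all inputs and upstream grads.
            all_grads_l.append(gov(inputs, output_grads))
    return all_grads_l

# Default for the first input, blocked second, custom third.
print(combine_grads(
    inputs=['x', 'y', 'z'],
    goverrides_l=[None, UNDEF, lambda ins, gs: 42.0],
    gdefaults=iter([1.5]),
    output_grads=[1.0]))  # -> [1.5, 0.0, 42.0]
```

Note that `gdefaults` is an iterator yielding default gradients only for the non-overridden inputs, which is why the loop consumes it with `next` rather than indexing by position.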
@@ -227,6 +239,8 @@ class OpFromGraph(gof.Op):
     def grad(self, inputs, output_grads):
         if not self._grad_op_is_cached:
             self._recompute_grad_op()
+        if self._grad_op is undef:
+            return [None for _ in self.input_types]
         return self._grad_op(*(list(inputs) + list(output_grads)))

     def make_node(self, *inputs):
...
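End to end, the user-visible effect of `grad_overrides=undef` should be that differentiating through the op yields zeros instead of the symbolic gradient. A hedged sketch under that assumption (the exact numeric outcome depends on how `theano.grad` treats the early return added above):

```python
import theano
import theano.tensor as T
from theano.compile.builders import OpFromGraph
from theano.gof.utils import undef

x = T.scalar('x')
f = OpFromGraph([x], [x ** 2], grad_overrides=undef)

g = T.grad(f(x), x)
fn = theano.function([x], g)
print(fn(3.0))  # expected 0.0 rather than 6.0: the override blocks the gradient
```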