Unverified 提交 9cccd347 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6618 from abergeron/opfromgraph_disconnected

Add option to override connection_pattern for OpFromGraph
......@@ -105,6 +105,11 @@ class OpFromGraph(gof.Op):
:class:`Variable <theano.gof.Variable>`. Each list element corresponds
to a specific output of R_op, length of list must be equal to number of outputs.
connection_pattern : list of list
If not ``None``, this will be used as the connection_pattern
for this op.
name : string, optional
A name for debugging purposes
......@@ -248,6 +253,7 @@ class OpFromGraph(gof.Op):
lop_overrides='default',
grad_overrides='default',
rop_overrides='default',
connection_pattern=None,
name=None, **kwargs
):
if not isinstance(outputs, list):
......@@ -298,6 +304,8 @@ class OpFromGraph(gof.Op):
self._lop_type = 'lop'
self.set_rop_overrides(rop_overrides)
self._connection_pattern = connection_pattern
if name is not None:
assert isinstance(name, str), 'name must be None or string object'
self.name = name
......@@ -637,6 +645,9 @@ class OpFromGraph(gof.Op):
Return connection pattern of subfgraph defined by inputs and outputs.
"""
if self._connection_pattern is not None:
return self._connection_pattern
inp_len = len(self.local_inputs)
out_len = len(self.local_outputs)
cpmat_self = io_connection_pattern(
......
......@@ -266,6 +266,37 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
# TODO list override case
@test_params
def test_connection_pattern_override(self, cls_ofg):
x, y = T.vectors('xy')
def f1(x, y):
del x
# but we know how to backpropagate for x for some reasons
# and we don't care about the gradient wrt y.
return y + T.round(y)
def f1_back(inputs, output_gradients):
return [
output_gradients[0],
theano.gradient.disconnected_type()]
op = cls_ofg(
inputs=[x, y],
outputs=[f1(x, y)],
grad_overrides=f1_back,
connection_pattern=[[True], [False]], # This is new
on_unused_input='ignore') # This is new
c = op(x, y)
g1 = theano.grad(c.sum(), x)
out = g1.eval({
x: np.ones((5,), dtype=np.float32),
y: np.ones((5,), dtype=np.float32)})
assert np.allclose(out, [1.] * 5)
@test_params
def test_nested(self, cls_ofg):
x, y = T.vectors('xy')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论