提交 1396175a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add option to override connection_pattern for OpFromGraph.

上级 9212db21
...@@ -105,6 +105,11 @@ class OpFromGraph(gof.Op): ...@@ -105,6 +105,11 @@ class OpFromGraph(gof.Op):
:class:`Variable <theano.gof.Variable>`. Each list element corresponds :class:`Variable <theano.gof.Variable>`. Each list element corresponds
to a specific output of R_op, length of list must be equal to number of outputs. to a specific output of R_op, length of list must be equal to number of outputs.
connection_pattern : list of list
If not ``None``, this will be used as the connection_pattern
for this op.
name : string, optional name : string, optional
A name for debugging purposes A name for debugging purposes
...@@ -248,6 +253,7 @@ class OpFromGraph(gof.Op): ...@@ -248,6 +253,7 @@ class OpFromGraph(gof.Op):
lop_overrides='default', lop_overrides='default',
grad_overrides='default', grad_overrides='default',
rop_overrides='default', rop_overrides='default',
connection_pattern=None,
name=None, **kwargs name=None, **kwargs
): ):
if not isinstance(outputs, list): if not isinstance(outputs, list):
...@@ -298,6 +304,8 @@ class OpFromGraph(gof.Op): ...@@ -298,6 +304,8 @@ class OpFromGraph(gof.Op):
self._lop_type = 'lop' self._lop_type = 'lop'
self.set_rop_overrides(rop_overrides) self.set_rop_overrides(rop_overrides)
self._connection_pattern = connection_pattern
if name is not None: if name is not None:
assert isinstance(name, str), 'name must be None or string object' assert isinstance(name, str), 'name must be None or string object'
self.name = name self.name = name
...@@ -637,6 +645,9 @@ class OpFromGraph(gof.Op): ...@@ -637,6 +645,9 @@ class OpFromGraph(gof.Op):
Return connection pattern of subfgraph defined by inputs and outputs. Return connection pattern of subfgraph defined by inputs and outputs.
""" """
if self._connection_pattern is not None:
return self._connection_pattern
inp_len = len(self.local_inputs) inp_len = len(self.local_inputs)
out_len = len(self.local_outputs) out_len = len(self.local_outputs)
cpmat_self = io_connection_pattern( cpmat_self = io_connection_pattern(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论