提交 a51c6ec4 authored 作者: ChienliMa's avatar ChienliMa

draft of infer_shape

上级 64eb0117
...@@ -4,8 +4,12 @@ from theano.compat import izip ...@@ -4,8 +4,12 @@ from theano.compat import izip
from theano.compile.function_module import orig_function from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared from theano.compile import SharedVariable, rebuild_collect_shared
from theano.gof import ops_with_inner_function from theano.gof import ops_with_inner_function
<<<<<<< HEAD
from theano.gof.graph import io_connection_pattern from theano.gof.graph import io_connection_pattern
=======
from theano.gof import graph, FunctionGraph
>>>>>>> draft of infer_shape
class OpFromGraph(gof.Op): class OpFromGraph(gof.Op):
"""This creates an `Op` from inputs and outputs lists of variables. """This creates an `Op` from inputs and outputs lists of variables.
...@@ -141,6 +145,17 @@ class OpFromGraph(gof.Op): ...@@ -141,6 +145,17 @@ class OpFromGraph(gof.Op):
""" """
return io_connection_pattern(self.new_inputs, self.new_outputs) return io_connection_pattern(self.new_inputs, self.new_outputs)
def infer_shape(self, node, shapes):
# Construct a new fgraph
equiv = graph.clone_get_equiv(self.new_inputs, self.new_outputs)
replacement = dict((equiv[var], var) for var in self.new_inputs)
out_cpy = theano.clone([equiv[var] for var in self.new_outputs], replace=replacement)
fg = FunctionGraph(self.new_inputs, out_cpy)
in_shapes = dict([(v, shp) for v, shp in zip(fg.inputs, shapes)])
all_shapes = theano.tensor.utils.shape_of_variables(fg, in_shapes)
return [all_shapes[var] for var in fg.outputs]
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for # OpFromGraph doesn't implement a connection_pattern, so for
# now we regard all inputs and outputs as connected. This will # now we regard all inputs and outputs as connected. This will
......
...@@ -151,6 +151,13 @@ class T_OpFromGraph(unittest.TestCase): ...@@ -151,6 +151,13 @@ class T_OpFromGraph(unittest.TestCase):
[True, False, True]] [True, False, True]]
assert results == expect_result assert results == expect_result
def test_infer_shape(self):
x = T.matrix('x')
y = x+x
op_graph = OpFromGraph([x], [y], mode='FAST_RUN')
shapes = op_graph.infer_shape(None, [(5, 5)])
assert shapes[0] == (5, 5)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论