提交 a51c6ec4 authored 作者: ChienliMa's avatar ChienliMa

draft of infer_shape

上级 64eb0117
......@@ -4,8 +4,12 @@ from theano.compat import izip
from theano.compile.function_module import orig_function
from theano.compile import SharedVariable, rebuild_collect_shared
from theano.gof import ops_with_inner_function
<<<<<<< HEAD
from theano.gof.graph import io_connection_pattern
=======
from theano.gof import graph, FunctionGraph
>>>>>>> draft of infer_shape
class OpFromGraph(gof.Op):
"""This creates an `Op` from inputs and outputs lists of variables.
......@@ -141,6 +145,17 @@ class OpFromGraph(gof.Op):
"""
return io_connection_pattern(self.new_inputs, self.new_outputs)
def infer_shape(self, node, shapes):
# Construct a new fgraph
equiv = graph.clone_get_equiv(self.new_inputs, self.new_outputs)
replacement = dict((equiv[var], var) for var in self.new_inputs)
out_cpy = theano.clone([equiv[var] for var in self.new_outputs], replace=replacement)
fg = FunctionGraph(self.new_inputs, out_cpy)
in_shapes = dict([(v, shp) for v, shp in zip(fg.inputs, shapes)])
all_shapes = theano.tensor.utils.shape_of_variables(fg, in_shapes)
return [all_shapes[var] for var in fg.outputs]
def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for
# now we regard all inputs and outputs as connected. This will
......
......@@ -151,6 +151,13 @@ class T_OpFromGraph(unittest.TestCase):
[True, False, True]]
assert results == expect_result
def test_infer_shape(self):
x = T.matrix('x')
y = x+x
op_graph = OpFromGraph([x], [y], mode='FAST_RUN')
shapes = op_graph.infer_shape(None, [(5, 5)])
assert shapes[0] == (5, 5)
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论