提交 664d0a96 authored 作者: ChienliMa's avatar ChienliMa

delete clone before and after infer_shape

上级 ed0e5b26
......@@ -142,16 +142,10 @@ class OpFromGraph(gof.Op):
return io_connection_pattern(self.new_inputs, self.new_outputs)
def infer_shape(self, node, shapes):
# clone fgraph
equiv = clone_get_equiv(self.new_inputs, self.new_outputs)
in_v = [equiv[var] for var in self.new_inputs]
out_v = [equiv[var] for var in self.new_outputs]
shape = theano.scan_module.scan_utils.infer_shape(out_v, in_v,
shape = theano.scan_module.scan_utils.infer_shape(self.new_outputs,
self.new_inputs,
shapes)
replacement = dict([(ori, rpl) for ori, rpl in izip(in_v, node.inputs)])
return [theano.clone(shape[0], replace=replacement)]
return shape
def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for
......@@ -185,3 +179,5 @@ class OpFromGraph(gof.Op):
# Since OpFromGraph contains a Theano compiled function, we should let
# DebugMode know about it
ops_with_inner_function[OpFromGraph] = 'fn'
......@@ -12,6 +12,7 @@ from theano.compile.builders import OpFromGraph
from theano.tests import unittest_tools
import unittest
class T_OpFromGraph(unittest_tools.InferShapeTester):
......@@ -155,8 +156,15 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
def test_infer_shape(self):
x = T.matrix('x')
y = x+x
op_graph = OpFromGraph([x], [y])
self._compile_and_check([x],
[op_graph(x)],
z = x*x
op_graph = OpFromGraph([x], [y,z])
q = T.matrix('q')
self._compile_and_check([q],
[op_graph(q)[0],op_graph(q)[1]],
[numpy.ones([3,4], dtype=config.floatX)],
OpFromGraph)
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论