提交 64a7b4ea authored 作者: ChienliMa's avatar ChienliMa

Clone only once

上级 4cbfea77
...@@ -143,12 +143,21 @@ class OpFromGraph(gof.Op): ...@@ -143,12 +143,21 @@ class OpFromGraph(gof.Op):
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
out_shp = theano.scan_module.scan_utils.infer_shape(self.new_outputs, out_shp = theano.scan_module.scan_utils.infer_shape(self.new_outputs,
self.new_inputs, self.new_inputs,
shapes) shapes)
replacement = dict([(ori, rpl) for ori, rpl replacement = dict([(ori, rpl) for ori, rpl
in izip(self.new_inputs, node.inputs)]) in izip(self.new_inputs, node.inputs)])
return [theano.clone(shape, replace=replacement) for shape in out_shp] repl = dict(zip(self.new_inputs, node.inputs))
cloned = theano.clone(reduce(tuple.__add__, out_shp), replace=repl)
ret = []
used = 0
for i in range(len(out_shp)):
nb = len(out_shp[i])
ret.append(cloned[used: used + nb])
used += nb
return ret
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for # OpFromGraph doesn't implement a connection_pattern, so for
......
...@@ -162,7 +162,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -162,7 +162,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
q = T.matrix('q') q = T.matrix('q')
p = T.matrix('p') p = T.matrix('p')
self._compile_and_check([q,p], self._compile_and_check([q,p],
[op_graph(q,p)[0],op_graph(q,p)[1]], op_graph(q,p),
[numpy.ones([3,4], dtype=config.floatX), [numpy.ones([3,4], dtype=config.floatX),
numpy.ones([3,4], dtype=config.floatX)], numpy.ones([3,4], dtype=config.floatX)],
OpFromGraph) OpFromGraph)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论