提交 64a7b4ea authored 作者: ChienliMa's avatar ChienliMa

Clone only once

上级 4cbfea77
......@@ -143,12 +143,21 @@ class OpFromGraph(gof.Op):
def infer_shape(self, node, shapes):
out_shp = theano.scan_module.scan_utils.infer_shape(self.new_outputs,
self.new_inputs,
shapes)
self.new_inputs,
shapes)
replacement = dict([(ori, rpl) for ori, rpl
in izip(self.new_inputs, node.inputs)])
return [theano.clone(shape, replace=replacement) for shape in out_shp]
repl = dict(zip(self.new_inputs, node.inputs))
cloned = theano.clone(reduce(tuple.__add__, out_shp), replace=repl)
ret = []
used = 0
for i in range(len(out_shp)):
nb = len(out_shp[i])
ret.append(cloned[used: used + nb])
used += nb
return ret
def grad(self, inputs, output_grads):
# OpFromGraph doesn't implement a connection_pattern, so for
......
......@@ -162,7 +162,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
q = T.matrix('q')
p = T.matrix('p')
self._compile_and_check([q,p],
[op_graph(q,p)[0],op_graph(q,p)[1]],
op_graph(q,p),
[numpy.ones([3,4], dtype=config.floatX),
numpy.ones([3,4], dtype=config.floatX)],
OpFromGraph)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论