提交 337f6c66 authored 作者: Frederic's avatar Frederic

(This was started with Raul Chandias Ferrari)

Flatten Composite. This help gh-689 as it allow it to support up to 57 iteration of the loop instead of just 43. This could also help Theano cache in the case where graph could have different fusion order, so different way the composite are in each other, but do the same computation. This make the printing of Composite more readable.
上级 b2ba9153
...@@ -2974,7 +2974,39 @@ class Composite(ScalarOp): ...@@ -2974,7 +2974,39 @@ class Composite(ScalarOp):
# We need to clone the graph as sometimes its nodes already # We need to clone the graph as sometimes its nodes already
# contain a reference to an fgraph. As we want the Composite # contain a reference to an fgraph. As we want the Composite
# to be pickable, we can't have reference to fgraph. # to be pickable, we can't have reference to fgraph.
inputs, outputs = gof.graph.clone(inputs, outputs)
# Also, if there is Composite in the inner graph, we want to
# remove them. In that case, we do a more complicated clone
# that will flatten Composite. We don't need to do this
# recusively, as the way the fusion optimizer work, we have
# only 1 new Composite each time at the output.
if not any([isinstance(var.owner.op, Composite) for var in outputs]):
# No inner Composite
inputs, outputs = gof.graph.clone(inputs, outputs)
else:
# Inner Composite that we need to flatten
assert len(outputs) == 1
# 1. Create a new graph from inputs up to the
# Composite
res = theano.compile.rebuild_collect_shared(
inputs=inputs,
outputs=outputs[0].owner.inputs,
copy_inputs_over=False) # Clone also the inputs
# 2. We continue this partial clone with the graph in
# the inner Composite
res2 = theano.compile.rebuild_collect_shared(
inputs=outputs[0].owner.op.inputs,
outputs=outputs[0].owner.op.outputs,
replace=dict(zip(outputs[0].owner.op.inputs, res[1]))
)
assert len(res2[1]) == len(outputs)
assert len(res[0]) == len(inputs)
assert res[0] != inputs
inputs, outputs = res[0], res2[1]
# Next assert comment just for speed
assert not any([isinstance(node.op, Composite) for node in
theano.gof.graph.ops(inputs, outputs)])
self.inputs = copy(inputs) self.inputs = copy(inputs)
self.outputs = copy(outputs) self.outputs = copy(outputs)
self.inputs_type = tuple([input.type for input in inputs]) self.inputs_type = tuple([input.type for input in inputs])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论