提交 8161bf9b authored 作者: Frederic Bastien's avatar Frederic Bastien

some improvement to Canonizer when their is DimShuffle.

上级 f784c7b5
......@@ -547,6 +547,14 @@ class Canonizer(gof.LocalOptimizer):
# the dtype of the 'input' argument. The leaf-Variables of the graph covered by the
# recursion may be of any Variable type.
if len(input.clients) > 1:
# this logic is too conservative, but doing it is better than not doing it.
#
# we don't want to canonize a subgraph that we will need to compute anyway for the other clients.
# This check is too conservative because if the other clients are also in the subgraph we are canonizing,
# then we should [probably?] recurse anyway.
return [input], []
if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
if input.owner and isinstance(input.owner.op, T.DimShuffle):
# If input is a DimShuffle of some input which does something like this:
......@@ -778,6 +786,18 @@ class Canonizer(gof.LocalOptimizer):
out = node.outputs[0]
assert len(node.outputs) == 1
# check if any of the clients of this node would be part of this canonized graph...
# if so, we do nothing and wait for them to be transformed.
def _bypass_dimshuffle(n):
if isinstance(n.op, DimShuffle) and len(n.outputs[0].clients) <= 1:
return _bypass_dimshuffle(n.outputs[0].clients.__iter__().next()[0])
else:
return n
for c,c_idx in out.clients:
if c=='output': continue
if _bypass_dimshuffle(c).op in [self.main, self.inverse, self.reciprocal]:
return False
# Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify
orig_num, orig_denum = self.get_num_denum(node.outputs[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论