提交 15ad6f31 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1549 from lamblin/fix_shape_cycle

[WIP] Fix shape cycle
...@@ -898,22 +898,29 @@ class ShapeFeature(object): ...@@ -898,22 +898,29 @@ class ShapeFeature(object):
r_shape = self.shape_of[r] r_shape = self.shape_of[r]
else: else:
# If no info is known on r's shape, use other_shape # If no info is known on r's shape, use other_shape
self.shape_of[r] = other_shape self.set_shape(r, other_shape)
for sv in other_shape:
self.shape_of_reverse_index.setdefault(sv, set()).add(r)
return return
# Merge other_shape with r_shape, giving the priority to other_shape # Merge other_shape with r_shape, giving the priority to other_shape
merged_shape = [] merged_shape = []
for i, ps in enumerate(other_shape): for i, ps in enumerate(other_shape):
# If other_shape[i] is uninformative, use r_shape[i].
# For now, we consider 2 cases of uninformative other_shape[i]:
# - Shape_i(i)(other_r);
# - Shape_i(i)(r).
if (ps.owner if (ps.owner
and isinstance(getattr(ps.owner, 'op', None), Shape_i) and isinstance(getattr(ps.owner, 'op', None), Shape_i)
and ps.owner.op.i == i and ps.owner.op.i == i
and ps.owner.inputs[0] in (r, other_r)): and ps.owner.inputs[0] in (r, other_r)):
# If other_shape[i] is uninformative, use r_shape[i].
# For now, we consider 2 cases of uninformative other_shape[i]:
# - Shape_i(i)(other_r);
# - Shape_i(i)(r).
merged_shape.append(r_shape[i])
elif r_shape[i] in theano.gof.graph.ancestors([other_shape[i]]):
# Another case where we want to use r_shape[i] is when
# other_shape[i] actually depends on r_shape[i]. In that case,
# we do not want to substitute an expression with another that
# is strictly more complex. Such a substitution could also lead
# to cycles: if (in the future) r_shape[i] gets replaced by an
# expression of other_shape[i], other_shape[i] may end up
# depending on itself.
merged_shape.append(r_shape[i]) merged_shape.append(r_shape[i])
else: else:
merged_shape.append(other_shape[i]) merged_shape.append(other_shape[i])
...@@ -1107,6 +1114,12 @@ class ShapeFeature(object): ...@@ -1107,6 +1114,12 @@ class ShapeFeature(object):
# replacement. # replacement.
continue continue
if shpnode.outputs[0] in theano.gof.graph.ancestors([repl]):
raise AssertionError(
"This substitution would insert a cycle in the graph:"
"node: %s, i: %i, r: %s, new_r: %s"
% (node, i, r, new_r))
self.scheduled[shpnode] = new_r self.scheduled[shpnode] = new_r
# In case 2, if r is a variable that we've scheduled for shape update, # In case 2, if r is a variable that we've scheduled for shape update,
# then we should cancel it. # then we should cancel it.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论