提交 bd25fb21 authored 作者: Frederic's avatar Frederic

Only change the index that would introduce extra nodes.

上级 6cc267d5
......@@ -968,19 +968,32 @@ class ShapeFeature(object):
isinstance(r.owner.op, Shape_i) and
r.owner.inputs[0] not in var.fgraph.variables):
assert var.owner
inp = var
node = inp.owner
node = var.owner
# TODO recur on inputs
# Need to time this to don't have it too slow.
# Make sure to handle the case of (shape_i(x)+1)
# for v in node.inputs:
# for idx in range(v.ndim):
# self.get_shape(v, idx)
# see https://github.com/Theano/Theano/issues/3560
o_shapes = self.get_node_infer_shape(node)
assert len(o_shapes) == len(node.outputs)
for shps, out in zip(o_shapes, node.outputs):
self.set_shape(out, shps, override=True)
r = o_shapes[node.outputs.index(inp)][r.owner.op.i]
# Only change the variables and dimensions that would introduce
# extra computation
for new_shps, out in zip(o_shapes, node.outputs):
if not hasattr(out, 'ndim'):
continue
merged_shps = list(self.shape_of[out])
changed = False
for i in range(out.ndim):
n_r = merged_shps[i]
if (n_r.owner and
isinstance(n_r.owner.op, Shape_i) and
n_r.owner.inputs[0] not in var.fgraph.variables):
changed = True
merged_shps[i] = new_shps[i]
if changed:
self.set_shape(out, merged_shps, override=True)
r = self.shape_of[var][idx]
return r
def shape_ir(self, i, r):
......@@ -1085,6 +1098,9 @@ class ShapeFeature(object):
----------
r : a variable
s : None or a tuple of symbolic integers
override : If False, it mean r is a new object in the fgraph.
If True, it mean r is already in the fgraph and we want to
override its shape.
"""
if not override:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论