提交 bd25fb21 authored 作者: Frederic's avatar Frederic

Only change the index that would introduce extra nodes.

上级 6cc267d5
...@@ -968,19 +968,32 @@ class ShapeFeature(object): ...@@ -968,19 +968,32 @@ class ShapeFeature(object):
isinstance(r.owner.op, Shape_i) and isinstance(r.owner.op, Shape_i) and
r.owner.inputs[0] not in var.fgraph.variables): r.owner.inputs[0] not in var.fgraph.variables):
assert var.owner assert var.owner
inp = var node = var.owner
node = inp.owner
# TODO recur on inputs # TODO recur on inputs
# Need to time this to don't have it too slow. # Need to time this to don't have it too slow.
# Make sure to handle the case of (shape_i(x)+1) # Make sure to handle the case of (shape_i(x)+1)
# for v in node.inputs: # see https://github.com/Theano/Theano/issues/3560
# for idx in range(v.ndim):
# self.get_shape(v, idx)
o_shapes = self.get_node_infer_shape(node) o_shapes = self.get_node_infer_shape(node)
assert len(o_shapes) == len(node.outputs) assert len(o_shapes) == len(node.outputs)
for shps, out in zip(o_shapes, node.outputs):
self.set_shape(out, shps, override=True) # Only change the variables and dimensions that would introduce
r = o_shapes[node.outputs.index(inp)][r.owner.op.i] # extra computation
for new_shps, out in zip(o_shapes, node.outputs):
if not hasattr(out, 'ndim'):
continue
merged_shps = list(self.shape_of[out])
changed = False
for i in range(out.ndim):
n_r = merged_shps[i]
if (n_r.owner and
isinstance(n_r.owner.op, Shape_i) and
n_r.owner.inputs[0] not in var.fgraph.variables):
changed = True
merged_shps[i] = new_shps[i]
if changed:
self.set_shape(out, merged_shps, override=True)
r = self.shape_of[var][idx]
return r return r
def shape_ir(self, i, r): def shape_ir(self, i, r):
...@@ -1085,6 +1098,9 @@ class ShapeFeature(object): ...@@ -1085,6 +1098,9 @@ class ShapeFeature(object):
---------- ----------
r : a variable r : a variable
s : None or a tuple of symbolic integers s : None or a tuple of symbolic integers
override : If False, it mean r is a new object in the fgraph.
If True, it mean r is already in the fgraph and we want to
override its shape.
""" """
if not override: if not override:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论