提交 23573a95 authored 作者: Frederic's avatar Frederic

Make local_elemwise_alloc don't re-introduce old node in the graph.

上级 ec7681c1
......@@ -955,6 +955,34 @@ class ShapeFeature(object):
return o_shapes
def get_shape(self, var, idx):
""" Optimization can call this to get the current shape_i
It is better to call this then use directly shape_of[var][idx]
as this method should update shape_of if needed.
TODO: Up to now, we don't update it in all cases. Update in all cases.
"""
r = self.shape_of[var][idx]
if (r.owner and
isinstance(r.owner.op, Shape_i) and
r.owner.inputs[0] not in var.fgraph.variables):
assert var.owner
inp = var
node = inp.owner
# TODO recur on inputs
# Need to time this to don't have it too slow.
# Make sure to handle the case of (shape_i(x)+1)
# for v in node.inputs:
# for idx in range(v.ndim):
# self.get_shape(v, idx)
o_shapes = self.get_node_infer_shape(node)
assert len(o_shapes) == len(node.outputs)
for shps, out in zip(o_shapes, node.outputs):
self.set_shape(out, shps, override=True)
r = o_shapes[node.outputs.index(inp)][r.owner.op.i]
return r
def shape_ir(self, i, r):
"""Return symbolic r.shape[i] for tensor variable r, int i."""
if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
......@@ -1050,7 +1078,7 @@ class ShapeFeature(object):
raise TypeError('Unsupported shape element',
s_i, type(s_i), getattr(s_i, 'type', None))
def set_shape(self, r, s):
def set_shape(self, r, s, override=False):
"""Assign the shape `s` to previously un-shaped variable `r`.
Parameters
......@@ -1059,7 +1087,8 @@ class ShapeFeature(object):
s : None or a tuple of symbolic integers
"""
assert r not in self.shape_of, 'r already in shape_of'
if not override:
assert r not in self.shape_of, 'r already in shape_of'
if s is None:
self.shape_of[r] = s
else:
......@@ -1503,12 +1532,17 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# when i.owner.inputs[0].type == i.owner.outputs[0].type we
# will remove that alloc later
assert i.type.ndim == cmp_op.ndim
get_shape = node.fgraph.shape_feature.get_shape
if (theano.config.experimental.local_alloc_elemwise_assert and
not same_shape(i, cmp_op)):
assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim)
if not i.type.broadcastable[idx]])
cond = []
for idx in xrange(i.type.ndim):
if not i.type.broadcastable[idx]:
# TODO: same_shape(i, cmp_op, dim_x=idx, dim_y=idx)
i_shp = get_shape(i, idx)
cmp_shp = get_shape(cmp_op, idx)
cond.append(T.eq(i_shp, cmp_shp))
assert_op = assert_(assert_op, *cond)
new_i.append(i.owner.inputs[0])
# Remove Alloc in DimShuffle
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论