提交 c5b120c8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3548 from nouiz/extra_node_infer_shape

[ENH,REGRESSION FIX] Fix regression that introduced extra nodes related to shape inference
......@@ -921,6 +921,80 @@ class ShapeFeature(object):
constants?? That would be confusing.
"""
def get_node_infer_shape(self, node):
    """Compute the output shapes of `node` via its Op's `infer_shape`.

    Falls back to ``self.default_infer_shape`` when the Op has no
    `infer_shape`, when it raises `ShapeError`, or (depending on
    ``config.on_shape_error``) when it raises any other exception.
    A `NotImplementedError` from `infer_shape` is re-raised with an
    explanatory message, as that protocol is no longer supported.
    """
    # The input shapes are needed on every path below; build them once.
    in_shapes = [self.shape_of[r] for r in node.inputs]
    # Ops without an infer_shape method get the default fallback.
    shape_infer = getattr(node.op, 'infer_shape', self.default_infer_shape)
    try:
        o_shapes = shape_infer(node, in_shapes)
    except ShapeError:
        # The Op declared it cannot compute the shape: use the default.
        o_shapes = self.default_infer_shape(node, in_shapes)
    except NotImplementedError as e:
        raise NotImplementedError(
            'Code called by infer_shape failed raising a '
            'NotImplementedError. Raising NotImplementedError to '
            'indicate that a shape cannot be computed is no longer '
            'supported, and one should now use tensor.ShapeError '
            'instead. The original exception message is: %s' % e)
    except Exception as e:
        msg = ('Failed to infer_shape from Op %s.\nInput shapes: '
               '%s\nException encountered during infer_shape: '
               '%s\nException message: %s\nTraceback: %s') % (
            node.op, in_shapes,
            type(e), str(e), traceback.format_exc())
        if config.on_shape_error == "raise":
            raise Exception(msg)
        # Best-effort mode: log the failure and fall back to the default.
        _logger.warning(msg)
        o_shapes = self.default_infer_shape(node, in_shapes)
    return o_shapes
def get_shape(self, var, idx):
    """Return the current symbolic shape of `var` along dimension `idx`.

    Optimizations should call this rather than reading
    ``shape_of[var][idx]`` directly, as this method refreshes
    ``shape_of`` when the stored entry is stale (a ``Shape_i`` of a
    variable that is no longer in the fgraph).

    TODO: Up to now, we don't update it in all cases. Update in all cases.
    """
    def _is_stale_shape_i(s):
        # True when `s` is a Shape_i whose input has left var's fgraph,
        # i.e. the cached shape entry can no longer be computed as-is.
        return (s.owner and
                isinstance(s.owner.op, Shape_i) and
                s.owner.inputs[0] not in var.fgraph.variables)

    r = self.shape_of[var][idx]
    if not _is_stale_shape_i(r):
        return r

    assert var.owner
    node = var.owner
    # TODO: recurse on the inputs too.
    # Need to time this so it does not become too slow.
    # Make sure to handle the case of (shape_i(x)+1);
    # see https://github.com/Theano/Theano/issues/3560
    o_shapes = self.get_node_infer_shape(node)
    assert len(o_shapes) == len(node.outputs)
    # Only replace the entries that would otherwise introduce extra
    # computation; everything else keeps its cached shape.
    for fresh_shps, out in zip(o_shapes, node.outputs):
        if not hasattr(out, 'ndim'):
            continue
        merged = list(self.shape_of[out])
        dirty = False
        for dim in range(out.ndim):
            if _is_stale_shape_i(merged[dim]):
                merged[dim] = fresh_shps[dim]
                dirty = True
        if dirty:
            self.set_shape(out, merged, override=True)
    # Re-read: set_shape above may have refreshed var's entry.
    return self.shape_of[var][idx]
def shape_ir(self, i, r):
"""Return symbolic r.shape[i] for tensor variable r, int i."""
......@@ -1017,16 +1091,20 @@ class ShapeFeature(object):
raise TypeError('Unsupported shape element',
s_i, type(s_i), getattr(s_i, 'type', None))
def set_shape(self, r, s):
def set_shape(self, r, s, override=False):
"""Assign the shape `s` to previously un-shaped variable `r`.
Parameters
----------
r : a variable
s : None or a tuple of symbolic integers
override : If False, it mean r is a new object in the fgraph.
If True, it mean r is already in the fgraph and we want to
override its shape.
"""
assert r not in self.shape_of, 'r already in shape_of'
if not override:
assert r not in self.shape_of, 'r already in shape_of'
if s is None:
self.shape_of[r] = s
else:
......@@ -1207,36 +1285,7 @@ class ShapeFeature(object):
# make sure we have shapes for the inputs
self.init_r(r)
try:
shape_infer = node.op.infer_shape
except AttributeError:
shape_infer = self.default_infer_shape
try:
o_shapes = shape_infer(node,
[self.shape_of[r] for r in node.inputs])
except ShapeError:
o_shapes = self.default_infer_shape(node, [self.shape_of[r] for
r in node.inputs])
except NotImplementedError as e:
raise NotImplementedError(
'Code called by infer_shape failed raising a '
'NotImplementedError. Raising NotImplementedError to '
'indicate that a shape cannot be computed is no longer '
'supported, and one should now use tensor.ShapeError '
'instead. The original exception message is: %s' % e)
except Exception as e:
msg = ('Failed to infer_shape from Op %s.\nInput shapes: '
'%s\nException encountered during infer_shape: '
'%s\nException message: %s\nTraceback: %s') % (
node.op, [self.shape_of[r] for r in node.inputs],
type(e), str(e), traceback.format_exc())
if config.on_shape_error == "raise":
raise Exception(msg)
else:
_logger.warning(msg)
o_shapes = self.default_infer_shape(
node, [self.shape_of[r] for r in node.inputs])
o_shapes = self.get_node_infer_shape(node)
# this is packed information
# an element of o_shapes is either None or a tuple
......@@ -1499,12 +1548,17 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# when i.owner.inputs[0].type == i.owner.outputs[0].type we
# will remove that alloc later
assert i.type.ndim == cmp_op.ndim
if (theano.config.experimental.local_alloc_elemwise_assert and
not same_shape(i, cmp_op)):
assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim)
if not i.type.broadcastable[idx]])
get_shape = node.fgraph.shape_feature.get_shape
if theano.config.experimental.local_alloc_elemwise_assert:
cond = []
for idx in xrange(i.type.ndim):
if (not i.type.broadcastable[idx] and
not same_shape(i, cmp_op, idx, idx)):
i_shp = get_shape(i, idx)
cmp_shp = get_shape(cmp_op, idx)
cond.append(T.eq(i_shp, cmp_shp))
if cond:
assert_op = assert_(assert_op, *cond)
new_i.append(i.owner.inputs[0])
# Remove Alloc in DimShuffle
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论