Commit 518b6f43 authored by Arnaud Bergeron

Implement a get_constant_idx helper method on Subtensor and use it when we need…

Implement a get_constant_idx helper method on Subtensor and use it when we need the constant values in opts. Also fix the problem with nnet opts when ShapeFeature is off.
Parent c49564d2
...@@ -591,20 +591,7 @@ def get_scalar_constant_value(v): ...@@ -591,20 +591,7 @@ def get_scalar_constant_value(v):
return ret[0][0] return ret[0][0]
if isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0: if isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0:
if isinstance(v.owner.inputs[0], TensorConstant): if isinstance(v.owner.inputs[0], TensorConstant):
indices = list(reversed(list(v.owner.inputs[1:]))) cdata = v.owner.op.get_constant_idx(v.owner.inputs)
def conv(e):
if isinstance(e, gof.Type):
return get_scalar_constant_value(indices.pop())
elif isinstance(e, slice):
return slice(conv(e.start),
conv(e.stop),
conv(e.step))
elif isinstance(e, (int, long, numpy.integer)):
return int(e)
else:
raise NotScalarConstantError(v)
cdata = tuple(map(conv, v.owner.op.idx_list))
try: try:
return v.owner.inputs[0].data.__getitem__(cdata) return v.owner.inputs[0].data.__getitem__(cdata)
except IndexError: except IndexError:
......
...@@ -1396,8 +1396,9 @@ def _check_rows_is_arange_len_labels(rows, labels): ...@@ -1396,8 +1396,9 @@ def _check_rows_is_arange_len_labels(rows, labels):
# ShapeOptimizer, but we keep it if ShapeOptimizer is not present # ShapeOptimizer, but we keep it if ShapeOptimizer is not present
if isinstance(stop.owner.op, subtensor.Subtensor): if isinstance(stop.owner.op, subtensor.Subtensor):
shape_subtensor = stop.owner shape_subtensor = stop.owner
if list(shape_subtensor.op.idx_list) == [0]: if shape_subtensor.op.get_constant_idx(shape_subtensor.inputs,
shape_var, = shape_subtensor.inputs allow_partial=True) == [0]:
shape_var = shape_subtensor.inputs[0]
if shape_var.owner and shape_var.owner.op == tensor.shape: if shape_var.owner and shape_var.owner.op == tensor.shape:
return shape_var.owner.inputs[0] is labels return shape_var.owner.inputs[0] is labels
else: else:
......
...@@ -1625,32 +1625,17 @@ def local_useless_subtensor(node): ...@@ -1625,32 +1625,17 @@ def local_useless_subtensor(node):
if not hasattr(node.fgraph, 'shape_feature'): if not hasattr(node.fgraph, 'shape_feature'):
return return
shape_of = node.fgraph.shape_feature.shape_of shape_of = node.fgraph.shape_feature.shape_of
idx_vals = get_idx_list(node.inputs, node.op.idx_list) cdata = node.op.get_constant_idx(node.inputs, allow_partial=True)
for pos, idx in enumerate(idx_vals): for pos, idx in enumerate(cdata):
if not isinstance(idx, slice): if not isinstance(idx, slice):
# If idx is not a slice, this means we remove this dimension # If idx is not a slice, this means we remove this dimension
# from the output, so the subtensor is not useless # from the output, so the subtensor is not useless
return False return False
if idx.start not in [0, None]:
# Grab the values for start/stop/step
try:
start = get_scalar_constant_value(idx.start)
except NotScalarConstantError:
start = idx.start
try:
stop = get_scalar_constant_value(idx.stop)
except NotScalarConstantError:
stop = idx.stop
try:
step = get_scalar_constant_value(idx.step)
except NotScalarConstantError:
step = idx.step
if start not in [0, None]:
# If the start of the slice is different from 0, or is a # If the start of the slice is different from 0, or is a
# variable, then we assume the subtensor is not useless # variable, then we assume the subtensor is not useless
return False return False
if step not in [1, None]: if idx.step not in [1, None]:
# If we are going backwards, or skipping elements, then this # If we are going backwards, or skipping elements, then this
# is not a useless subtensor # is not a useless subtensor
return False return False
...@@ -1663,11 +1648,11 @@ def local_useless_subtensor(node): ...@@ -1663,11 +1648,11 @@ def local_useless_subtensor(node):
except NotScalarConstantError: except NotScalarConstantError:
pass pass
if isinstance(stop, (int, numpy.integer)): if isinstance(idx.stop, (int, numpy.integer)):
if stop < length_pos_data: if stop < length_pos_data:
return False return False
elif isinstance(stop, gof.Variable): elif isinstance(idx.stop, gof.Variable):
length_pos_shape_i = stop length_pos_shape_i = idx.stop
# length_pos is a tensor variable, but length_pos_shape_i # length_pos is a tensor variable, but length_pos_shape_i
# is a scalar variable. We try to see if they represent # is a scalar variable. We try to see if they represent
# the same underlying variable. # the same underlying variable.
...@@ -1694,7 +1679,7 @@ def local_useless_subtensor(node): ...@@ -1694,7 +1679,7 @@ def local_useless_subtensor(node):
# length_pos_shape_i cannot be None # length_pos_shape_i cannot be None
if length_pos_shape_i != length_pos: if length_pos_shape_i != length_pos:
return False return False
elif stop is None: elif idx.stop is None:
pass pass
else: else:
return False return False
......
...@@ -362,6 +362,37 @@ class Subtensor(Op): ...@@ -362,6 +362,37 @@ class Subtensor(Op):
else: else:
raise AdvancedIndexingError(Subtensor.e_indextype, entry) raise AdvancedIndexingError(Subtensor.e_indextype, entry)
def get_constant_idx(self, inputs, allow_partial=False):
    """Return ``self.idx_list`` with constant inputs replaced by their
    Python scalar equivalents.

    :param inputs: this node's inputs; ``inputs[0]`` is the tensor
        being indexed and the remaining entries are the index
        variables consumed by ``idx_list`` (same layout that
        ``get_idx_list`` expects).
    :param allow_partial: if True, entries that are not constant are
        left as their input variable instead of raising.

    :raises theano.tensor.NotScalarConstantError: if an index entry
        is not constant and ``allow_partial`` is False.

    ``None`` entries (unspecified slice bounds) are always left as-is.
    """
    real_idx = get_idx_list(inputs, self.idx_list)

    def conv(val):
        # Unspecified slice parts stay None.
        if val is None:
            return None
        elif isinstance(val, slice):
            # Recurse into the three components of the slice.
            return slice(conv(val.start),
                         conv(val.stop),
                         conv(val.step))
        else:
            try:
                return get_scalar_constant_value(val)
            except theano.tensor.NotScalarConstantError:
                if allow_partial:
                    return val
                else:
                    raise

    # list(...) keeps the return type a concrete list on both Python 2
    # and Python 3 -- a bare map() would be a lazy iterator on Python 3,
    # which breaks callers that compare the result against a list
    # (e.g. `get_constant_idx(...) == [0]`).
    return list(map(conv, real_idx))
def __init__(self, idx_list): def __init__(self, idx_list):
self.idx_list = tuple(map(self.convert, idx_list)) self.idx_list = tuple(map(self.convert, idx_list))
self.perform_cache_cdata = None self.perform_cache_cdata = None
...@@ -404,27 +435,16 @@ class Subtensor(Op): ...@@ -404,27 +435,16 @@ class Subtensor(Op):
% (input.type, expected_type)) % (input.type, expected_type))
# infer the broadcasting pattern # infer the broadcasting pattern
padded = (idx_list padded = (self.get_constant_idx((None,)+inputs, allow_partial=True)
+ [slice(None, None, None)] * (x.type.ndim - len(idx_list))) + [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
idx_padded = get_idx_list((None,)+inputs, padded)
broadcastable = [] broadcastable = []
for i, (p, bc) in enumerate(izip(idx_padded, x.type.broadcastable)): for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)):
if isinstance(p, slice): if isinstance(p, slice):
# figure out the value of start and stop (if they are constant) if bc and p.start in [None, 0]:
try: start = p.start
if p.start is None: if start is None:
start = 0 start = 0
else: if p.stop is None or p.stop > start:
start = get_scalar_constant_value(p.start)
if p.stop is None:
stop = 1
else:
stop = get_scalar_constant_value(p.stop)
except theano.tensor.NotScalarConstantError:
start = None
stop = None
if bc and start == 0 and stop > start:
broadcastable.append(True) broadcastable.append(True)
continue continue
......
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment