提交 22ce58a2 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1729 from abergeron/subtensor_const

Lift subtensor constant indexes as constants in the graph rather than keep them op-internal.
...@@ -557,7 +557,7 @@ def get_scalar_constant_value(v): ...@@ -557,7 +557,7 @@ def get_scalar_constant_value(v):
data = v.data data = v.data
return numpy_scalar(data) return numpy_scalar(data)
if v.owner: if getattr(v, 'owner', None):
if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast, if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
compile.ops.OutputGuard, compile.ops.OutputGuard,
compile.DeepCopyOp)): compile.DeepCopyOp)):
...@@ -590,14 +590,10 @@ def get_scalar_constant_value(v): ...@@ -590,14 +590,10 @@ def get_scalar_constant_value(v):
v.owner.op.perform(v.owner, const, ret) v.owner.op.perform(v.owner, const, ret)
return ret[0][0] return ret[0][0]
if isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0: if isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0:
# This condition depends on Subtensor always embedding constant if isinstance(v.owner.inputs[0], TensorConstant):
# indices in the Op rather than making them inputs to the Apply cdata = tuple(v.owner.op.get_constant_idx(v.owner.inputs))
# node.
if isinstance(v.owner.inputs[0], TensorConstant) and \
len(v.owner.inputs) == 1:
try: try:
return v.owner.inputs[0].data.__getitem__( return v.owner.inputs[0].data.__getitem__(cdata)
tuple(v.owner.op.idx_list))
except IndexError: except IndexError:
raise IndexError( raise IndexError(
str(tuple(v.owner.op.idx_list)) + str(tuple(v.owner.op.idx_list)) +
...@@ -620,10 +616,12 @@ def get_scalar_constant_value(v): ...@@ -620,10 +616,12 @@ def get_scalar_constant_value(v):
v.owner.inputs[0].owner.inputs) and v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1): len(v.owner.op.idx_list) == 1):
idx = v.owner.op.idx_list[0]
if isinstance(idx, gof.Type):
idx = get_scalar_constant_value(v.owner.inputs[1])
# Note the '+ 1' is because the first argument to Join is the # Note the '+ 1' is because the first argument to Join is the
# axis. # axis.
ret = v.owner.inputs[0].owner.inputs[ ret = v.owner.inputs[0].owner.inputs[idx + 1]
v.owner.op.idx_list[0] + 1]
ret = get_scalar_constant_value(ret) ret = get_scalar_constant_value(ret)
# join can cast implicitly its input in some case. # join can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype) return theano._asarray(ret, dtype=v.type.dtype)
...@@ -635,14 +633,13 @@ def get_scalar_constant_value(v): ...@@ -635,14 +633,13 @@ def get_scalar_constant_value(v):
# We put this check in case there is change in the future # We put this check in case there is change in the future
python_all(var.ndim == 0 for var in python_all(var.ndim == 0 for var in
v.owner.inputs[0].owner.inputs) and v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1 and len(v.owner.op.idx_list) == 1):
#idx_list can contain Scalar Type object. idx = v.owner.op.idx_list[0]
isinstance(v.owner.op.idx_list[0], (int, long, if isinstance(idx, gof.Type):
numpy.integer))): idx = get_scalar_constant_value(v.owner.inputs[1])
# Python 2.4 does not support indexing with numpy.integer # Python 2.4 does not support indexing with numpy.integer
# So we cast it. # So we cast it.
idx = int(v.owner.op.idx_list[0]) idx = int(idx)
ret = v.owner.inputs[0].owner.inputs[idx] ret = v.owner.inputs[0].owner.inputs[idx]
ret = get_scalar_constant_value(ret) ret = get_scalar_constant_value(ret)
# MakeVector can cast implicitly its input in some case. # MakeVector can cast implicitly its input in some case.
...@@ -658,6 +655,8 @@ def get_scalar_constant_value(v): ...@@ -658,6 +655,8 @@ def get_scalar_constant_value(v):
op = owner.op op = owner.op
idx_list = op.idx_list idx_list = op.idx_list
idx = idx_list[0] idx = idx_list[0]
if isinstance(idx, gof.Type):
idx = get_scalar_constant_value(owner.inputs[1])
grandparent = leftmost_parent.owner.inputs[0] grandparent = leftmost_parent.owner.inputs[0]
gp_broadcastable = grandparent.type.broadcastable gp_broadcastable = grandparent.type.broadcastable
ndim = grandparent.type.ndim ndim = grandparent.type.ndim
......
...@@ -1396,8 +1396,9 @@ def _check_rows_is_arange_len_labels(rows, labels): ...@@ -1396,8 +1396,9 @@ def _check_rows_is_arange_len_labels(rows, labels):
# ShapeOptimizer, but we keep it if ShapeOptimizer is not present # ShapeOptimizer, but we keep it if ShapeOptimizer is not present
if isinstance(stop.owner.op, subtensor.Subtensor): if isinstance(stop.owner.op, subtensor.Subtensor):
shape_subtensor = stop.owner shape_subtensor = stop.owner
if list(shape_subtensor.op.idx_list) == [0]: if shape_subtensor.op.get_constant_idx(shape_subtensor.inputs,
shape_var, = shape_subtensor.inputs allow_partial=True) == [0]:
shape_var = shape_subtensor.inputs[0]
if shape_var.owner and shape_var.owner.op == tensor.shape: if shape_var.owner and shape_var.owner.op == tensor.shape:
return shape_var.owner.inputs[0] is labels return shape_var.owner.inputs[0] is labels
else: else:
......
...@@ -1633,8 +1633,8 @@ def local_useless_subtensor(node): ...@@ -1633,8 +1633,8 @@ def local_useless_subtensor(node):
if not hasattr(node.fgraph, 'shape_feature'): if not hasattr(node.fgraph, 'shape_feature'):
return return
shape_of = node.fgraph.shape_feature.shape_of shape_of = node.fgraph.shape_feature.shape_of
node_input_idx = 1 cdata = node.op.get_constant_idx(node.inputs, allow_partial=True)
for pos, idx in enumerate(node.op.idx_list): for pos, idx in enumerate(cdata):
if not isinstance(idx, slice): if not isinstance(idx, slice):
# If idx is not a slice, this means we remove this dimension # If idx is not a slice, this means we remove this dimension
# from the output, so the subtensor is not useless # from the output, so the subtensor is not useless
...@@ -1659,8 +1659,8 @@ def local_useless_subtensor(node): ...@@ -1659,8 +1659,8 @@ def local_useless_subtensor(node):
if isinstance(idx.stop, (int, numpy.integer)): if isinstance(idx.stop, (int, numpy.integer)):
if idx.stop < length_pos_data: if idx.stop < length_pos_data:
return False return False
elif isinstance(idx.stop, theano.scalar.Scalar): elif isinstance(idx.stop, gof.Variable):
length_pos_shape_i = node.inputs[node_input_idx] length_pos_shape_i = idx.stop
# length_pos is a tensor variable, but length_pos_shape_i # length_pos is a tensor variable, but length_pos_shape_i
# is a scalar variable. We try to see if they represent # is a scalar variable. We try to see if they represent
# the same underlying variable. # the same underlying variable.
...@@ -1683,9 +1683,6 @@ def local_useless_subtensor(node): ...@@ -1683,9 +1683,6 @@ def local_useless_subtensor(node):
assert str(length_pos.type.dtype) == "int64" assert str(length_pos.type.dtype) == "int64"
assert str(length_pos_shape_i.type.dtype) in ["int8", "int16", assert str(length_pos_shape_i.type.dtype) in ["int8", "int16",
"int32", "int64"] "int32", "int64"]
# We already know that start and step are not variables
# and so they don't appear in the input of the node
node_input_idx += 1
# length_pos_shape_i cannot be None # length_pos_shape_i cannot be None
if length_pos_shape_i != length_pos: if length_pos_shape_i != length_pos:
...@@ -1745,8 +1742,7 @@ def local_subtensor_lift(node): ...@@ -1745,8 +1742,7 @@ def local_subtensor_lift(node):
return [u.owner.op(*new_inputs)] return [u.owner.op(*new_inputs)]
if isinstance(u.owner.op, T.Rebroadcast): if isinstance(u.owner.op, T.Rebroadcast):
# make sure that Subtensor and Rebroadcast only have 1 input/output # make sure that Rebroadcast has only 1 input
assert len(node.inputs) == 1
assert len(u.owner.inputs) == 1 assert len(u.owner.inputs) == 1
# Subtensor might reduce dim., adapt broadcast pattern accordingly # Subtensor might reduce dim., adapt broadcast pattern accordingly
...@@ -1768,7 +1764,7 @@ def local_subtensor_lift(node): ...@@ -1768,7 +1764,7 @@ def local_subtensor_lift(node):
new_axis += [(j, u.broadcastable[i])] new_axis += [(j, u.broadcastable[i])]
j += 1 j += 1
subt_x = Subtensor(node.op.idx_list)(u.owner.inputs[0]) subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x) rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x)
return [rbcast_subt_x] return [rbcast_subt_x]
......
...@@ -347,24 +347,56 @@ class Subtensor(Op): ...@@ -347,24 +347,56 @@ class Subtensor(Op):
slice_c = None slice_c = None
return slice(slice_a, slice_b, slice_c) return slice(slice_a, slice_b, slice_c)
# There is a bug in numpy that results in isinstance(x, int) returning elif isinstance(entry, (int, long, numpy.integer)):
# False for numpy integers. # Disallow the use of python scalars in idx_list
# See <http://projects.scipy.org/numpy/ticket/2235>. raise TypeError("Python scalar in idx_list."
elif isinstance(entry, numpy.integer): "Please report this error to theano-dev.")
return entry
# On Windows 64-bit, shapes are returned as Python long, as they can
# be bigger than what a Python int can hold.
# Shapes should always fit in a numpy.int64, and we support them better
# 2) In Python3, long replaced int. So we must assert it fit in int64.
elif isinstance(entry, (int, long)):
entry64 = numpy.int64(entry)
return entry64
else: else:
raise AdvancedIndexingError(Subtensor.e_indextype, entry) raise AdvancedIndexingError(Subtensor.e_indextype, entry)
def get_constant_idx(self, inputs, allow_partial=False):
"""
Return the idx_list with constant inputs replaced by their
python scalar equivalent. May raise
`theano.tensor.NotScalarConstantError` if the idx contains
non-constant entries.
If allow_partial is True, then entries that are not constant
will stay as their input variable rather than raising an
exception.
None entries are always left as-is.
Example usage (where v, a are appropriately typed theano variables):
>>> b = a[v, 1:3]
>>> b.owner.op.idx_list
(Scalar(int64), slice(Scalar(int64), Scalar(int64), None))
>>> b.owner.op.get_constant_idx(b.owner.inputs, allow_partial=True)
[v, slice(1, 3, None)]
>>> b.owner.op.get_constant_idx(b.owner.inputs)
NotScalarConstantError: v
"""
real_idx = get_idx_list(inputs, self.idx_list)
def conv(val):
if val is None:
return None
elif isinstance(val, slice):
return slice(conv(val.start),
conv(val.stop),
conv(val.step))
else:
try:
return get_scalar_constant_value(val)
except theano.tensor.NotScalarConstantError:
if allow_partial:
return val
else:
raise
return map(conv, real_idx)
def __init__(self, idx_list): def __init__(self, idx_list):
self.idx_list = tuple(map(self.convert, idx_list)) self.idx_list = tuple(map(self.convert, idx_list))
self.perform_cache_cdata = None
@staticmethod @staticmethod
def my_as_scalar(a): def my_as_scalar(a):
...@@ -404,31 +436,21 @@ class Subtensor(Op): ...@@ -404,31 +436,21 @@ class Subtensor(Op):
% (input.type, expected_type)) % (input.type, expected_type))
# infer the broadcasting pattern # infer the broadcasting pattern
padded = (idx_list padded = (self.get_constant_idx((None,)+inputs, allow_partial=True)
+ [slice(None, None, None)] * (x.type.ndim - len(idx_list))) + [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
broadcastable = [] broadcastable = []
for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)): for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)):
if isinstance(p, slice): if isinstance(p, slice):
if bc and p.start in [None, 0]: if bc and p.start in [None, 0]:
# No need to check step when there is only start = p.start
# one element. if start is None:
# We could call get_canonical_form_slice() to start = 0
# catch more broadcast case. I let this to if (p.stop is None or
# later. (isinstance(p.stop, (int, numpy.integer)) and
if p.stop is None: p.stop > start)):
broadcastable.append(bc) broadcastable.append(True)
continue continue
try:
if p.start is None:
start = 0
else:
start = get_scalar_constant_value(p.start)
stop = get_scalar_constant_value(p.stop)
if stop > start:
broadcastable.append(True)
continue
except theano.tensor.NotScalarConstantError:
pass
broadcastable.append(False) broadcastable.append(False)
return gof.Apply(self, return gof.Apply(self,
...@@ -440,18 +462,9 @@ class Subtensor(Op): ...@@ -440,18 +462,9 @@ class Subtensor(Op):
out, = out_ out, = out_
x = inputs[0] x = inputs[0]
# The subtensor (or idx_list) does not depend on the inputs.
# (and cdata was cached on initial call)
if self.perform_cache_cdata is not None:
out[0] = numpy.asarray(x.__getitem__(self.perform_cache_cdata))
return
cdata = get_idx_list(inputs, self.idx_list) cdata = get_idx_list(inputs, self.idx_list)
if len(cdata) == 1: if len(cdata) == 1:
cdata = cdata[0] cdata = cdata[0]
# (first call caches cdata here)
if len(inputs) == 1:
self.perform_cache_cdata = cdata
out[0] = numpy.asarray(x.__getitem__(cdata)) out[0] = numpy.asarray(x.__getitem__(cdata))
......
...@@ -4,7 +4,8 @@ import numpy ...@@ -4,7 +4,8 @@ import numpy
import theano import theano
from theano.compat import all, PY3 from theano.compat import all, PY3
from theano.scalar import ComplexError, IntegerDivisionError from theano.scalar import (ComplexError, IntegerDivisionError,
ScalarConstant, int64)
from theano.gof import Constant, Variable from theano.gof import Constant, Variable
from theano.gof.utils import hashtype from theano.gof.utils import hashtype
from theano.tensor.utils import hash_from_ndarray from theano.tensor.utils import hash_from_ndarray
...@@ -348,6 +349,19 @@ class _tensor_py_operators: ...@@ -348,6 +349,19 @@ class _tensor_py_operators:
def __getitem__(self, args): def __getitem__(self, args):
if not isinstance(args, tuple): if not isinstance(args, tuple):
args = args, args = args,
# Convert python literals to theano constants
def conv(a):
if a is None:
return a
elif isinstance(a, slice):
return slice(conv(a.start),
conv(a.stop),
conv(a.step))
elif isinstance(a, (int, long, numpy.integer)):
return ScalarConstant(int64, a)
else:
return a
args = tuple(map(conv, args))
# Determine if advanced indexing is needed or not # Determine if advanced indexing is needed or not
# The logic is already in Subtensor.convert: if it succeeds, # The logic is already in Subtensor.convert: if it succeeds,
# standard indexing is used; if it fails with # standard indexing is used; if it fails with
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论