提交 22ce58a2，作者：Frédéric Bastien

Merge pull request #1729 from abergeron/subtensor_const

Lift subtensor constant indexes as constants in the graph rather than keep them op-internal.
......@@ -557,7 +557,7 @@ def get_scalar_constant_value(v):
data = v.data
return numpy_scalar(data)
if v.owner:
if getattr(v, 'owner', None):
if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
compile.ops.OutputGuard,
compile.DeepCopyOp)):
......@@ -590,14 +590,10 @@ def get_scalar_constant_value(v):
v.owner.op.perform(v.owner, const, ret)
return ret[0][0]
if isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0:
# This condition depends on Subtensor always embedding constant
# indices in the Op rather than making them inputs to the Apply
# node.
if isinstance(v.owner.inputs[0], TensorConstant) and \
len(v.owner.inputs) == 1:
if isinstance(v.owner.inputs[0], TensorConstant):
cdata = tuple(v.owner.op.get_constant_idx(v.owner.inputs))
try:
return v.owner.inputs[0].data.__getitem__(
tuple(v.owner.op.idx_list))
return v.owner.inputs[0].data.__getitem__(cdata)
except IndexError:
raise IndexError(
str(tuple(v.owner.op.idx_list)) +
......@@ -620,10 +616,12 @@ def get_scalar_constant_value(v):
v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1):
idx = v.owner.op.idx_list[0]
if isinstance(idx, gof.Type):
idx = get_scalar_constant_value(v.owner.inputs[1])
# Note the '+ 1' is because the first argument to Join is the
# axis.
ret = v.owner.inputs[0].owner.inputs[
v.owner.op.idx_list[0] + 1]
ret = v.owner.inputs[0].owner.inputs[idx + 1]
ret = get_scalar_constant_value(ret)
# join can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype)
......@@ -635,14 +633,13 @@ def get_scalar_constant_value(v):
# We put this check in case there is change in the future
python_all(var.ndim == 0 for var in
v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1 and
#idx_list can contain Scalar Type object.
isinstance(v.owner.op.idx_list[0], (int, long,
numpy.integer))):
len(v.owner.op.idx_list) == 1):
idx = v.owner.op.idx_list[0]
if isinstance(idx, gof.Type):
idx = get_scalar_constant_value(v.owner.inputs[1])
# Python 2.4 does not support indexing with numpy.integer
# So we cast it.
idx = int(v.owner.op.idx_list[0])
idx = int(idx)
ret = v.owner.inputs[0].owner.inputs[idx]
ret = get_scalar_constant_value(ret)
# MakeVector can cast implicitly its input in some case.
......@@ -658,6 +655,8 @@ def get_scalar_constant_value(v):
op = owner.op
idx_list = op.idx_list
idx = idx_list[0]
if isinstance(idx, gof.Type):
idx = get_scalar_constant_value(owner.inputs[1])
grandparent = leftmost_parent.owner.inputs[0]
gp_broadcastable = grandparent.type.broadcastable
ndim = grandparent.type.ndim
......
......@@ -1396,8 +1396,9 @@ def _check_rows_is_arange_len_labels(rows, labels):
# ShapeOptimizer, but we keep it if ShapeOptimizer is not present
if isinstance(stop.owner.op, subtensor.Subtensor):
shape_subtensor = stop.owner
if list(shape_subtensor.op.idx_list) == [0]:
shape_var, = shape_subtensor.inputs
if shape_subtensor.op.get_constant_idx(shape_subtensor.inputs,
allow_partial=True) == [0]:
shape_var = shape_subtensor.inputs[0]
if shape_var.owner and shape_var.owner.op == tensor.shape:
return shape_var.owner.inputs[0] is labels
else:
......
......@@ -1633,8 +1633,8 @@ def local_useless_subtensor(node):
if not hasattr(node.fgraph, 'shape_feature'):
return
shape_of = node.fgraph.shape_feature.shape_of
node_input_idx = 1
for pos, idx in enumerate(node.op.idx_list):
cdata = node.op.get_constant_idx(node.inputs, allow_partial=True)
for pos, idx in enumerate(cdata):
if not isinstance(idx, slice):
# If idx is not a slice, this means we remove this dimension
# from the output, so the subtensor is not useless
......@@ -1659,8 +1659,8 @@ def local_useless_subtensor(node):
if isinstance(idx.stop, (int, numpy.integer)):
if idx.stop < length_pos_data:
return False
elif isinstance(idx.stop, theano.scalar.Scalar):
length_pos_shape_i = node.inputs[node_input_idx]
elif isinstance(idx.stop, gof.Variable):
length_pos_shape_i = idx.stop
# length_pos is a tensor variable, but length_pos_shape_i
# is a scalar variable. We try to see if they represent
# the same underlying variable.
......@@ -1683,9 +1683,6 @@ def local_useless_subtensor(node):
assert str(length_pos.type.dtype) == "int64"
assert str(length_pos_shape_i.type.dtype) in ["int8", "int16",
"int32", "int64"]
# We already know that start and step are not variables
# and so they don't appear in the input of the node
node_input_idx += 1
# length_pos_shape_i cannot be None
if length_pos_shape_i != length_pos:
......@@ -1745,8 +1742,7 @@ def local_subtensor_lift(node):
return [u.owner.op(*new_inputs)]
if isinstance(u.owner.op, T.Rebroadcast):
# make sure that Subtensor and Rebroadcast only have 1 input/output
assert len(node.inputs) == 1
# make sure that Rebroadcast has only 1 input
assert len(u.owner.inputs) == 1
# Subtensor might reduce dim., adapt broadcast pattern accordingly
......@@ -1768,7 +1764,7 @@ def local_subtensor_lift(node):
new_axis += [(j, u.broadcastable[i])]
j += 1
subt_x = Subtensor(node.op.idx_list)(u.owner.inputs[0])
subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x)
return [rbcast_subt_x]
......
......@@ -347,24 +347,56 @@ class Subtensor(Op):
slice_c = None
return slice(slice_a, slice_b, slice_c)
# There is a bug in numpy that results in isinstance(x, int) returning
# False for numpy integers.
# See <http://projects.scipy.org/numpy/ticket/2235>.
elif isinstance(entry, numpy.integer):
return entry
# On Windows 64-bit, shapes are returned as Python long, as they can
# be bigger than what a Python int can hold.
# Shapes should always fit in a numpy.int64, and we support them better
# 2) In Python3, long replaced int. So we must assert it fit in int64.
elif isinstance(entry, (int, long)):
entry64 = numpy.int64(entry)
return entry64
elif isinstance(entry, (int, long, numpy.integer)):
# Disallow the use of python scalars in idx_list
raise TypeError("Python scalar in idx_list."
"Please report this error to theano-dev.")
else:
raise AdvancedIndexingError(Subtensor.e_indextype, entry)
def get_constant_idx(self, inputs, allow_partial=False):
    """
    Return this op's ``idx_list`` with every constant index input
    replaced by its Python scalar value.

    Slice entries are converted element-wise (start/stop/step), and
    ``None`` entries are always passed through unchanged.

    :param inputs: the inputs of the Apply node using this op; they are
        matched against ``self.idx_list`` via ``get_idx_list``.
    :param allow_partial: when True, entries whose value cannot be
        determined as a constant are left as their variable instead of
        raising.
    :raises theano.tensor.NotScalarConstantError: if ``allow_partial``
        is False and some index entry is not constant.

    Example usage (where v, a are appropriately typed theano variables):
    >>> b = a[v, 1:3]
    >>> b.owner.op.idx_list
    (Scalar(int64), slice(Scalar(int64), Scalar(int64), None))
    >>> b.owner.op.get_constant_idx(b.owner.inputs, allow_partial=True)
    [v, slice(1, 3, None)]
    >>> b.owner.op.get_constant_idx(b.owner.inputs)
    NotScalarConstantError: v
    """
    resolved_idx = get_idx_list(inputs, self.idx_list)

    def _convert(entry):
        # Leave missing slice components untouched.
        if entry is None:
            return None
        # Recurse into each component of a slice.
        if isinstance(entry, slice):
            return slice(_convert(entry.start),
                         _convert(entry.stop),
                         _convert(entry.step))
        try:
            return get_scalar_constant_value(entry)
        except theano.tensor.NotScalarConstantError:
            if allow_partial:
                # Caller accepts non-constant entries as-is.
                return entry
            raise

    # Build the converted list (list result, matching map() on Python 2).
    return [_convert(entry) for entry in resolved_idx]
def __init__(self, idx_list):
self.idx_list = tuple(map(self.convert, idx_list))
self.perform_cache_cdata = None
@staticmethod
def my_as_scalar(a):
......@@ -404,31 +436,21 @@ class Subtensor(Op):
% (input.type, expected_type))
# infer the broadcasting pattern
padded = (idx_list
+ [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
padded = (self.get_constant_idx((None,)+inputs, allow_partial=True)
+ [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
broadcastable = []
for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)):
if isinstance(p, slice):
if bc and p.start in [None, 0]:
# No need to check step when there is only
# one element.
# We could call get_canonical_form_slice() to
# catch more broadcast case. I let this to
# later.
if p.stop is None:
broadcastable.append(bc)
start = p.start
if start is None:
start = 0
if (p.stop is None or
(isinstance(p.stop, (int, numpy.integer)) and
p.stop > start)):
broadcastable.append(True)
continue
try:
if p.start is None:
start = 0
else:
start = get_scalar_constant_value(p.start)
stop = get_scalar_constant_value(p.stop)
if stop > start:
broadcastable.append(True)
continue
except theano.tensor.NotScalarConstantError:
pass
broadcastable.append(False)
return gof.Apply(self,
......@@ -440,18 +462,9 @@ class Subtensor(Op):
out, = out_
x = inputs[0]
# The subtensor (or idx_list) does not depend on the inputs.
# (and cdata was cached on initial call)
if self.perform_cache_cdata is not None:
out[0] = numpy.asarray(x.__getitem__(self.perform_cache_cdata))
return
cdata = get_idx_list(inputs, self.idx_list)
if len(cdata) == 1:
cdata = cdata[0]
# (first call caches cdata here)
if len(inputs) == 1:
self.perform_cache_cdata = cdata
out[0] = numpy.asarray(x.__getitem__(cdata))
......
......@@ -4,7 +4,8 @@ import numpy
import theano
from theano.compat import all, PY3
from theano.scalar import ComplexError, IntegerDivisionError
from theano.scalar import (ComplexError, IntegerDivisionError,
ScalarConstant, int64)
from theano.gof import Constant, Variable
from theano.gof.utils import hashtype
from theano.tensor.utils import hash_from_ndarray
......@@ -348,6 +349,19 @@ class _tensor_py_operators:
def __getitem__(self, args):
if not isinstance(args, tuple):
args = args,
# Convert python literals to theano constants
def conv(a):
if a is None:
return a
elif isinstance(a, slice):
return slice(conv(a.start),
conv(a.stop),
conv(a.step))
elif isinstance(a, (int, long, numpy.integer)):
return ScalarConstant(int64, a)
else:
return a
args = tuple(map(conv, args))
# Determine if advanced indexing is needed or not
# The logic is already in Subtensor.convert: if it succeeds,
# standard indexing is used; if it fails with
......
Markdown 格式
0%
您在此讨论中添加了 0 个人。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论