提交 b18616bd authored 作者: Benjamin Scellier's avatar Benjamin Scellier

remove legacy conditions of advanced_inc_subtensor

上级 55596d1b
......@@ -6,8 +6,6 @@ import subprocess
import sys
from locale import getpreferredencoding
import numpy
from theano import config
from theano.compat import decode, decode_with
from theano.configdefaults import local_bitwidth
......
from __future__ import absolute_import, print_function, division
from copy import copy
import sys
from textwrap import dedent
import warnings
......@@ -13,23 +12,19 @@ import theano
from theano.compat import izip
from theano.gradient import DisconnectedType
from theano import gof
from theano.gof import Apply, Constant, hashtype, Op, Type, MethodNotDefined
from theano.gof import Apply, hashtype, Op, Type, MethodNotDefined
from theano.printing import pprint
from theano import scalar as scal
from theano.tensor.basic import alloc
from theano.tensor.basic import (addbroadcast, clip, get_scalar_constant_value,
ARange, TensorType, NotScalarConstantError)
TensorType, NotScalarConstantError)
from theano.tensor.elemwise import DimShuffle
from theano.tensor.type_other import NoneConst, SliceType, NoneTypeT, make_slice
from theano import config
inplace_increment = None
if config.cxx:
import theano.gof.cutils # needed to import cutils_ext
try:
from cutils_ext.cutils_ext import inplace_increment
except ImportError:
pass
from cutils_ext.cutils_ext import inplace_increment
_logger = logging.getLogger("theano.tensor.subtensor")
......@@ -2001,8 +1996,9 @@ class AdvancedIncSubtensor1(Op):
if self.set_instead_of_inc:
x[idx] = y
else:
increment = inplace_increment
if increment is None:
if config.cxx:
increment = inplace_increment
else:
increment = self.inplace_increment1d_slow
increment(x, idx, y)
......@@ -2193,12 +2189,6 @@ advanced_subtensor = AdvancedSubtensor()
class AdvancedIncSubtensor(Op):
"""
Increments a subtensor using advanced indexing.
Notes
-----
We need the numpy.inplace_increment() function currently
numpy's PR 326 to be able to make an inplace version of this op.
"""
__props__ = ("inplace", "set_instead_of_inc")
......@@ -2213,8 +2203,6 @@ class AdvancedIncSubtensor(Op):
raise NotImplementedError('In place computation is not'
' implemented')
self.allow_legacy_perform = False
def __str__(self):
return "%s{%s, %s}" % (self.__class__.__name__,
"inplace=" + str(self.inplace),
......@@ -2225,46 +2213,12 @@ class AdvancedIncSubtensor(Op):
x = theano.tensor.as_tensor_variable(x)
y = theano.tensor.as_tensor_variable(y)
op = self
# If we are incrementing, but the increment compiled function is not
# available, we need to support legacy cases.
if not self.set_instead_of_inc and inplace_increment is None:
legacy_conditions = False
if x.ndim == 2 and y.ndim == 1 and len(inputs) == 2:
ind1 = theano.tensor.as_tensor_variable(inputs[0])
ind2 = theano.tensor.as_tensor_variable(inputs[1])
if ind1.ndim == 1 and ind2.ndim == 1:
if ind1.owner and isinstance(ind1.owner.op, ARange):
legacy_conditions = True
elif isinstance(ind1, Constant):
# Make sure no index is duplicated
val = ind1.value
if numpy.unique(val).size == val.size:
legacy_conditions = True
elif ind2.owner and isinstance(ind2.owner.op, ARange):
legacy_conditions = True
elif isinstance(ind2, Constant):
# Make sure no index is duplicated
val = ind2.value
if numpy.unique(val).size == val.size:
legacy_conditions = True
if legacy_conditions:
op = copy(self)
op.allow_legacy_perform = True
else:
raise NotImplementedError(
'Could not import inplace_increment, so some advanced '
'indexing features are disabled. They will be '
'available if you update NumPy to version 1.8 or '
'later, or to the latest development version. '
'You may need to clear the cache (theano-cache clear) '
'afterwards.')
new_inputs = []
for inp in inputs:
if isinstance(inp, (list, tuple)):
inp = theano.tensor.as_tensor_variable(inp)
new_inputs.append(inp)
return gof.Apply(op,
return gof.Apply(self,
(x, y) + tuple(new_inputs),
[theano.tensor.tensor(
dtype=x.type.dtype,
......@@ -2282,18 +2236,14 @@ class AdvancedIncSubtensor(Op):
if self.set_instead_of_inc:
out[0][inputs[2:]] = inputs[1]
elif inplace_increment is not None:
elif config.cxx:
inplace_increment(out[0], tuple(inputs[2:]), inputs[1])
elif self.allow_legacy_perform:
out[0][inputs[2:]] += inputs[1]
else:
raise NotImplementedError(
'Could not import inplace_increment, so some advanced '
'indexing features are disabled. They will be '
'available if you update NumPy to version 1.8 or '
'later, or to the latest development version. '
'You may need to clear the cache (theano-cache clear) '
'afterwards.')
'Could not import inplace_increment, so advanced '
'indexing is disabled. '
'Please make sure that you have a working C++ compiler '
'and that config.cxx is correctly set.')
def infer_shape(self, node, ishapes):
return [ishapes[0]]
......
......@@ -30,7 +30,8 @@ from theano.tensor.subtensor import (AdvancedIncSubtensor,
advanced_set_subtensor,
advanced_set_subtensor1,
get_canonical_form_slice, inc_subtensor,
inplace_increment, set_subtensor)
set_subtensor)
from theano.tensor.tests.test_basic import inplace_func, rand, randint_ranged
from theano.tests import unittest_tools as utt
from theano.tests.unittest_tools import attr
......@@ -1340,12 +1341,6 @@ class TestIncSubtensor1(unittest.TestCase):
utt.assert_allclose(out1val, out2val)
inplace_increment_missing = SkipTest(
"inc_subtensor with advanced indexing not enabled. "
"Installing NumPy 1.8 or the latest development version "
"should make that feature available.")
class TestAdvancedSubtensor(unittest.TestCase):
# test inc_subtensor
# also tests set_subtensor
......@@ -1494,8 +1489,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
utt.assert_allclose(rval, aval)
def test_inc_adv_subtensor_w_2vec(self):
if inplace_increment is None:
raise inplace_increment_missing
if not config.cxx:
raise SkipTest('config.cxx empty')
subt = self.m[self.ix1, self.ix12]
a = inc_subtensor(subt, subt)
......@@ -1515,8 +1510,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
[.5, .3 * 2, .15]]), aval
def test_inc_adv_subtensor_with_broadcasting(self):
if inplace_increment is None:
raise inplace_increment_missing
if not config.cxx:
raise SkipTest('config.cxx empty')
inc = dscalar()
a = inc_subtensor(self.m[self.ix1, self.ix12], inc)
......@@ -1538,8 +1533,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
assert numpy.allclose(gval, 3.0), gval
def test_inc_adv_subtensor1_with_broadcasting(self):
if inplace_increment is None:
raise inplace_increment_missing
if not config.cxx:
raise SkipTest('config.cxx empty')
inc = dscalar()
a = inc_subtensor(self.m[self.ix1], inc)
......@@ -1560,8 +1555,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
assert numpy.allclose(gval, 9.0), gval
def test_inc_adv_subtensor_with_index_broadcasting(self):
if inplace_increment is None:
raise inplace_increment_missing
if not config.cxx:
raise SkipTest('config.cxx empty')
a = inc_subtensor(self.m[self.ix1, self.ix2], 2.1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论