提交 5f9b2a8c authored 作者: John Salvatier's avatar John Salvatier

simplified increment check

上级 4c7c7bc3
...@@ -25,6 +25,10 @@ from theano.printing import pprint, min_informative_str ...@@ -25,6 +25,10 @@ from theano.printing import pprint, min_informative_str
from theano.tensor.utils import hash_from_ndarray from theano.tensor.utils import hash_from_ndarray
import theano.gof.cutils #needed to import cutils_ext import theano.gof.cutils #needed to import cutils_ext
try:
from cutils_ext.cutils_ext import inplace_increment
except ImportError:
inplace_increment = None
# We use these exceptions as well. # We use these exceptions as well.
from theano.scalar import ComplexError, IntegerDivisionError from theano.scalar import ComplexError, IntegerDivisionError
...@@ -7136,9 +7140,8 @@ class AdvancedIncSubtensor1(Op): ...@@ -7136,9 +7140,8 @@ class AdvancedIncSubtensor1(Op):
if self.set_instead_of_inc: if self.set_instead_of_inc:
x[idx] = y x[idx] = y
else: else:
try : increment = inplace_increment
from cutils_ext.cutils_ext import inplace_increment as increment if increment is None:
except ImportError:
increment = self.inplace_increment1d_slow increment = self.inplace_increment1d_slow
increment(x,idx, y) increment(x,idx, y)
...@@ -7376,17 +7379,6 @@ class AdvancedIncSubtensor(Op): ...@@ -7376,17 +7379,6 @@ class AdvancedIncSubtensor(Op):
op. op.
""" """
increment_available = None
@classmethod
def check_increment_available(cls):
if cls.increment_available is None:
try:
from cutils_ext.cutils_ext import (
inplace_increment as increment)
cls.increment_available = True
except ImportError:
cls.increment_available = False
def __init__(self, inplace=False, set_instead_of_inc=False): def __init__(self, inplace=False, set_instead_of_inc=False):
self.inplace = inplace self.inplace = inplace
...@@ -7398,16 +7390,13 @@ class AdvancedIncSubtensor(Op): ...@@ -7398,16 +7390,13 @@ class AdvancedIncSubtensor(Op):
raise NotImplementedError('In place computation is not' raise NotImplementedError('In place computation is not'
' implemented') ' implemented')
# The first time we instanciate an AdvancedIncSubtensor without
# set_instead_of_inc, check if the "increment" function is available
if not set_instead_of_inc:
self.check_increment_available()
# This flag enables the legacy implementation of "perform" for
# advanced_inc_subtensor. This implementation is incorrect in general,
# but gives correct results when no element is indexed more than once.
self.allow_legacy_perform = False self.allow_legacy_perform = False
@classmethod
@property
def increment_available():
return inplace_increment is not None
def __hash__(self): def __hash__(self):
return hash((type(self), self.inplace, self.set_instead_of_inc)) return hash((type(self), self.inplace, self.set_instead_of_inc))
...@@ -7422,11 +7411,7 @@ class AdvancedIncSubtensor(Op): ...@@ -7422,11 +7411,7 @@ class AdvancedIncSubtensor(Op):
" set_instead_of_inc=" + str(self. set_instead_of_inc)) " set_instead_of_inc=" + str(self. set_instead_of_inc))
def __setstate__(self, state): def __setstate__(self, state):
# We do not want to pickle increment_available, as it depends
# on the machine. Instead, we call check_increment_available on load
self.__dict__.update(state) self.__dict__.update(state)
if self.set_instead_of_inc:
self.check_increment_available()
def make_node(self, x, y, *inputs): def make_node(self, x, y, *inputs):
x = as_tensor_variable(x) x = as_tensor_variable(x)
...@@ -7482,8 +7467,8 @@ class AdvancedIncSubtensor(Op): ...@@ -7482,8 +7467,8 @@ class AdvancedIncSubtensor(Op):
if self.set_instead_of_inc: if self.set_instead_of_inc:
out[0][inputs[2:]] = inputs[1] out[0][inputs[2:]] = inputs[1]
elif self.increment_available: elif inplace_increment is not None:
increment(out[0], tuple(inputs[2:]), inputs[1]) inplace_increment(out[0], tuple(inputs[2:]), inputs[1])
elif self.allow_legacy_perform: elif self.allow_legacy_perform:
out[0][inputs[2:]] += inputs[1] out[0][inputs[2:]] += inputs[1]
else: else:
......
...@@ -3675,7 +3675,6 @@ class TestIncSubtensor1(unittest.TestCase): ...@@ -3675,7 +3675,6 @@ class TestIncSubtensor1(unittest.TestCase):
# also tests set_subtensor # also tests set_subtensor
def setUp(self): def setUp(self):
AdvancedIncSubtensor.check_increment_available()
self.s = iscalar() self.s = iscalar()
self.v = fvector() self.v = fvector()
self.m = dmatrix() self.m = dmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论