提交 ef3bba72 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: John Salvatier

Enable legacy implementation of perform if safe

上级 15f3c28e
...@@ -8,7 +8,7 @@ from itertools import izip ...@@ -8,7 +8,7 @@ from itertools import izip
from textwrap import dedent from textwrap import dedent
import numpy import numpy
#from copy import copy as python_copy from copy import copy as python_copy
import theano import theano
from theano.compat import PY3 from theano.compat import PY3
...@@ -7375,6 +7375,17 @@ class AdvancedIncSubtensor(Op): ...@@ -7375,6 +7375,17 @@ class AdvancedIncSubtensor(Op):
op. op.
""" """
increment_available = None
@classmethod
def check_increment_available(cls):
if cls.increment_available is None:
try:
from cutils_ext.cutils_ext import (
inplace_increment as increment)
cls.increment_available = True
except ImportError:
cls.increment_available = False
def __init__(self, inplace=False, set_instead_of_inc=False): def __init__(self, inplace=False, set_instead_of_inc=False):
self.inplace = inplace self.inplace = inplace
...@@ -7386,6 +7397,16 @@ class AdvancedIncSubtensor(Op): ...@@ -7386,6 +7397,16 @@ class AdvancedIncSubtensor(Op):
raise NotImplementedError('In place computation is not' raise NotImplementedError('In place computation is not'
' implemented') ' implemented')
# The first time we instanciate an AdvancedIncSubtensor without
# set_instead_of_inc, check if the "increment" function is available
if not set_instead_of_inc:
self.check_increment_available()
# This flag enables the legacy implementation of "perform" for
# advanced_inc_subtensor. This implementation is incorrect in general,
# but gives correct results when no element is indexed more than once.
self.allow_legacy_perform = False
def __hash__(self): def __hash__(self):
return hash((type(self), self.inplace, self.set_instead_of_inc)) return hash((type(self), self.inplace, self.set_instead_of_inc))
...@@ -7399,11 +7420,51 @@ class AdvancedIncSubtensor(Op): ...@@ -7399,11 +7420,51 @@ class AdvancedIncSubtensor(Op):
"inplace=" + str(self.inplace), "inplace=" + str(self.inplace),
" set_instead_of_inc=" + str(self. set_instead_of_inc)) " set_instead_of_inc=" + str(self. set_instead_of_inc))
def __setstate__(self, state):
# We do not want to pickle increment_available, as it depends
# on the machine. Instead, we call check_increment_available on load
self.__dict__.update(state)
if self.set_instead_of_inc:
self.check_increment_available()
def make_node(self, x, y, *inputs): def make_node(self, x, y, *inputs):
x = as_tensor_variable(x) x = as_tensor_variable(x)
y = as_tensor_variable(y) y = as_tensor_variable(y)
return gof.Apply(self, op = self
# If we are incrementing, but the increment compiled function is not
# available, we need to support legacy cases.
if not self.set_instead_of_inc and not self.increment_available:
legacy_conditions = False
if x.ndim == 2 and y.ndim == 1 and len(inputs) == 2:
ind1 = as_tensor_variable(inputs[0])
ind2 = as_tensor_variable(inputs[1])
if ind1.ndim == 1 and ind2.ndim == 1:
if ind1.owner and isinstance(ind1.owner.op, ARange):
legacy_conditions = True
elif isinstance(ind1, Constant):
# Make sure no index is duplicated
val = ind1.value
if numpy.unique(val).size == val.size:
legacy_conditions = True
elif ind2.owner and isinstance(ind2.owner.op, ARange):
legacy_conditions = True
elif isinstance(ind2, Constant):
# Make sure no index is duplicated
val = ind2.value
if numpy.unique(val).size == val.size:
legacy_conditions = True
if legacy_conditions:
op = python_copy(self)
op.allow_legacy_perform = True
else:
raise NotImplementedError(
'Could not import inplace_increment, so some advanced '
'indexing features are disabled. They will be '
'available if you update NumPy to version 1.8 or '
'later, or to the latest development version.')
return gof.Apply(op,
(x, y) + inputs, (x, y) + inputs,
[tensor(dtype=x.type.dtype, [tensor(dtype=x.type.dtype,
broadcastable=x.type.broadcastable)]) broadcastable=x.type.broadcastable)])
...@@ -7420,16 +7481,16 @@ class AdvancedIncSubtensor(Op): ...@@ -7420,16 +7481,16 @@ class AdvancedIncSubtensor(Op):
if self.set_instead_of_inc: if self.set_instead_of_inc:
out[0][inputs[2:]] = inputs[1] out[0][inputs[2:]] = inputs[1]
else: elif self.increment_available:
increment = None
try :
from cutils_ext.cutils_ext import inplace_increment as increment
except ImportError:
raise NotImplementedError('Did not find inplace_increment.'
'Update numpy?')
increment(out[0], tuple(inputs[2:]), inputs[1]) increment(out[0], tuple(inputs[2:]), inputs[1])
elif self.allow_legacy_perform:
out[0][inputs[2:]] += inputs[1]
else:
raise NotImplementedError(
'Could not import inplace_increment, so some advanced '
'indexing features are disabled. They will be '
'available if you update NumPy to version 1.8 or '
'later, or to the latest development version.')
if (numpy.__version__ <= '1.6.1' and if (numpy.__version__ <= '1.6.1' and
out[0].size != numpy.uint32(out[0].size)): out[0].size != numpy.uint32(out[0].size)):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论