提交 c60bd3b2 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use numpy add ufunc instead of custom code.

上级 e04074a9
...@@ -22,10 +22,6 @@ from theano.tensor.elemwise import DimShuffle ...@@ -22,10 +22,6 @@ from theano.tensor.elemwise import DimShuffle
from theano.tensor.type_other import NoneConst, SliceType, NoneTypeT, make_slice from theano.tensor.type_other import NoneConst, SliceType, NoneTypeT, make_slice
from theano import config from theano import config
if config.cxx:
import theano.gof.cutils # needed to import cutils_ext
from cutils_ext.cutils_ext import inplace_increment
_logger = logging.getLogger("theano.tensor.subtensor") _logger = logging.getLogger("theano.tensor.subtensor")
# Do a lazy import of the sparse module # Do a lazy import of the sparse module
...@@ -2001,36 +1997,10 @@ class AdvancedIncSubtensor1(Op): ...@@ -2001,36 +1997,10 @@ class AdvancedIncSubtensor1(Op):
if self.set_instead_of_inc: if self.set_instead_of_inc:
x[idx] = y x[idx] = y
else: else:
if config.cxx and node.inputs[0].dtype != 'float16': np.add.at(x, idx, y)
increment = inplace_increment
else:
increment = self.inplace_increment1d_slow
increment(x, idx, y)
out[0] = x out[0] = x
def inplace_increment1d_slow(self, x, idx, y):
# If `y` has as many dimensions as `x`, then we want to iterate
# jointly on `x` and `y`. Otherwise, it means `y` should be
# broadcasted to fill all relevant rows of `x`.
assert y.ndim <= x.ndim # Should be guaranteed by `make_node`
if y.ndim == x.ndim:
if len(y) == 1:
# Allow broadcasting of y[0]
y_0 = y[0]
for i in idx:
x[i] += y_0
else:
assert len(y) == len(idx)
j = 0
for i in idx:
x[i] += y[j]
j += 1
else:
for i in idx:
x[i] += y
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
x, y, ilist = ishapes x, y, ilist = ishapes
return [x] return [x]
...@@ -2246,14 +2216,8 @@ class AdvancedIncSubtensor(Op): ...@@ -2246,14 +2216,8 @@ class AdvancedIncSubtensor(Op):
if self.set_instead_of_inc: if self.set_instead_of_inc:
out[0][inputs[2:]] = inputs[1] out[0][inputs[2:]] = inputs[1]
elif config.cxx:
inplace_increment(out[0], tuple(inputs[2:]), inputs[1])
else: else:
raise NotImplementedError( np.add.at(out[0], tuple(inputs[2:]), inputs[1])
'Could not import inplace_increment, so advanced '
'indexing is disabled. '
'Please make sure that you have a working C++ compiler '
'and that config.cxx is correctly set.')
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
return [ishapes[0]] return [ishapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论