提交 9d94cece authored 作者: jsalvatier's avatar jsalvatier 提交者: John Salvatier

initial try at making advanced subtensor and subtensor inc work in the general case

上级 69e23761
......@@ -7187,7 +7187,51 @@ class AdvancedIncSubtensor1(Op):
advanced_inc_subtensor1 = AdvancedIncSubtensor1()
from itertools import groupby, chain
def simpleindex(a):
try:
return as_tensor_variable(a).ndim == 0
except:
return True
def simple_broadcastable(a, idx):
def replace_slice(v):
if isinstance(v, slice):
return slice(None,None)
if simpleindex(v):
return 0
return v
newidx = tuple(map(replace_slice, idx))
fakeshape = [bc + 1 for bc in a.broadcastable]
retshape = numpy.empty(fakeshape)[newidx].shape
return tuple([dim == 1 for dim in retshape])
from __builtin__ import sum as concat
def concat(ls):
r = []
map(r.extend, ls)
return r
def advanced_broadcastable(a, idx):
chunks = list(groupby(idx, simpleindex))
chunks = [(s, list(c)) for s,c in chunks]
if len(chunks) > 3:
chunks = [concat(c for s, c in chunks if s), concat(c for s, c in chunks if not s)]
def getbroad((simple, c)):
if simple:
return simple_broadcastable(a, c)
else:
return as_tensor_variable(c[0]).broadcastable
return concat(map(getbroad, chunks))
class AdvancedSubtensor(Op):
"""Return a subtensor copy, using advanced indexing.
"""
......@@ -7204,37 +7248,16 @@ class AdvancedSubtensor(Op):
def __str__(self):
return self.__class__.__name__
def make_node(self, x, *inputs):
def make_node(self, x, *index):
x = as_tensor_variable(x)
# FIXME
# should be replaced with something that includes support for None and slices
# Note (9 Jul 2012): what does this 'FIXME' mean? Possibly that the
# current implementation must be generalized? Please specify.
if x.ndim == 2 and len(inputs) == 2:
ind1 = as_tensor_variable(inputs[0])
ind2 = as_tensor_variable(inputs[1])
if (not (ind1.type.dtype.startswith('int') or
ind1.type.dtype.startswith('uint'))):
raise TypeError(
'the indices into a matrix must be int or uint. It is ',
ind1.type.dtype)
if (not (ind2.type.dtype.startswith('int') or
ind2.type.dtype.startswith('uint'))):
raise TypeError(
'the indices into a matrix must be int or uint. It is ',
ind2.type.dtype)
if ind1.ndim == 1 and ind2.ndim == 1:
return gof.Apply(self,
(x, ind1, ind2),
[tensor(dtype=x.type.dtype,
broadcastable=[False])])
raise NotImplementedError(
'Advanced indexing of x (of dimension %i) with these argument'
' dimensions (%s) not supported yet'
% (x.ndim, ','.join(str(input.ndim) for input in inputs)))
raise NotImplementedError(
'Advanced indexing of x with arguments (%s) not supported yet'
% ','.join(str(input) for input in inputs))
return gof.Apply(self,
(x,) + tuple(map(as_tensor_variable, index)),
[tensor(dtype = x.type.dtype,
broadcastable = advanced_broadcastable(x, index) )])
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
......@@ -7326,24 +7349,10 @@ class AdvancedIncSubtensor(Op):
x = as_tensor_variable(x)
y = as_tensor_variable(y)
if x.ndim == 2 and y.ndim == 1 and len(inputs) == 2:
ind1 = as_tensor_variable(inputs[0])
ind2 = as_tensor_variable(inputs[1])
if ind1.ndim == 1 and ind2.ndim == 1:
return gof.Apply(self,
return gof.Apply(self,
(x, y) + inputs,
[tensor(dtype=x.type.dtype,
broadcastable=x.type.broadcastable)])
raise NotImplementedError(
'Advanced indexing increment/set of x (of dimension %i) by y'
' (of dimension %i) with these argument dimensions (%s) not'
' supported yet'
% (x.ndim, y.ndim,
','.join(str(input.ndim) for input in inputs)))
raise NotImplementedError(
'Advanced indexing increment/set of x (of dim %i) by y (of dim %i)'
' with arguments (%s) not supported yet'
% (x.ndim, y.ndim, ','.join(str(input) for input in inputs)))
def perform(self, node, inputs, out_):
# TODO: 1. opt to make this in place 2. generalize as described in
......@@ -7351,14 +7360,13 @@ class AdvancedIncSubtensor(Op):
out, = out_
if not self.inplace:
out[0] = inputs[0].copy()
else:
raise NotImplementedError('In place computation is not'
' implemented')
if self.set_instead_of_inc:
a = inputs[0].copy()
numpy.inplace_increment(a, tuple(inputs[2:]), inputs[1])
out[0] = a
out[0][inputs[2:]] = inputs[1]
else:
out[0][inputs[2:]] += inputs[1]
if (numpy.__version__ <= '1.6.1' and
out[0].size != numpy.uint32(out[0].size)):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论