Commit 9d94cece authored by jsalvatier, committed by John Salvatier

initial try at making advanced subtensor and subtensor inc work in the general case

Parent 69e23761
...@@ -7187,7 +7187,51 @@ class AdvancedIncSubtensor1(Op): ...@@ -7187,7 +7187,51 @@ class AdvancedIncSubtensor1(Op):
advanced_inc_subtensor1 = AdvancedIncSubtensor1() advanced_inc_subtensor1 = AdvancedIncSubtensor1()
from itertools import groupby, chain
def simpleindex(a):
    """Return True iff `a` behaves as a "simple" (non-advanced) index.

    A simple index is either a scalar (0-dimensional when converted to a
    tensor variable) or something that cannot be converted to a tensor at
    all (e.g. a slice or None), as opposed to an array/"advanced" index.
    """
    try:
        return as_tensor_variable(a).ndim == 0
    except Exception:
        # as_tensor_variable raises for slices, None, etc.; treat those as
        # simple indexing.  Catch Exception rather than a bare `except:` so
        # KeyboardInterrupt/SystemExit are not swallowed.
        return True
def simple_broadcastable(a, idx):
    """Compute the broadcastable pattern of `a[idx]` for a "simple"
    (non-advanced) index tuple `idx`.

    Strategy: build a tiny fake numpy array whose shape encodes `a`'s
    broadcastable pattern, apply an equivalent simple index to it, and
    read the output pattern back off the resulting shape.
    """
    def replace_slice(v):
        # Slices keep their dimension (normalized to a full slice so the
        # fake array need not be large); scalar indices drop theirs.
        if isinstance(v, slice):
            return slice(None, None)
        if simpleindex(v):
            return 0
        return v

    newidx = tuple(map(replace_slice, idx))
    # Broadcastable dims have real length 1; give every other dim length 2
    # so it cannot read back as broadcastable.  (The previous `bc + 1`
    # mapped broadcastable -> 2 and non-broadcastable -> 1, inverting the
    # resulting pattern.)
    fakeshape = [1 if bc else 2 for bc in a.broadcastable]
    retshape = numpy.empty(fakeshape)[newidx].shape
    return tuple(dim == 1 for dim in retshape)
from __builtin__ import sum as concat
def concat(ls):
    """Concatenate an iterable of lists into one flat list.

    Replaces the previous `map(r.extend, ls)` formulation, which abused
    `map` for its side effects (and would silently do nothing under
    Python 3, where `map` is lazy).
    """
    return list(chain.from_iterable(ls))
def advanced_broadcastable(a, idx):
    """Compute the broadcastable pattern of `a[idx]` where `idx` may mix
    simple indices (scalars, slices) and advanced (array) indices.

    The indices are grouped into consecutive runs of simple vs. advanced
    indices.  With more than three runs, numpy moves the advanced-indexed
    dimensions to the front of the result, so the runs are collapsed into
    [all simple, all advanced].
    """
    # Consume each group while iterating groupby: the group iterators share
    # one underlying iterator and are invalidated as soon as groupby
    # advances, so the previous `list(groupby(...))` followed by listing
    # each group produced empty groups.
    chunks = [(simple, list(group)) for simple, group in groupby(idx, simpleindex)]
    if len(chunks) > 3:
        # Keep the (is_simple, indices) pair structure when collapsing;
        # previously this produced bare lists, which getbroad mis-unpacked.
        chunks = [(True, concat(c for s, c in chunks if s)),
                  (False, concat(c for s, c in chunks if not s))]

    def getbroad(chunk):
        # Tuple parameters (`def f((a, b))`) are Python-2-only syntax;
        # unpack explicitly instead.
        simple, c = chunk
        if simple:
            return simple_broadcastable(a, c)
        # All advanced indices in a run broadcast together; the first one's
        # pattern stands in for the run.  TODO(review): confirm this holds
        # when the arrays in the run have differing patterns.
        return as_tensor_variable(c[0]).broadcastable

    return concat(map(getbroad, chunks))
class AdvancedSubtensor(Op): class AdvancedSubtensor(Op):
"""Return a subtensor copy, using advanced indexing. """Return a subtensor copy, using advanced indexing.
""" """
...@@ -7204,37 +7248,16 @@ class AdvancedSubtensor(Op): ...@@ -7204,37 +7248,16 @@ class AdvancedSubtensor(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def make_node(self, x, *inputs): def make_node(self, x, *index):
x = as_tensor_variable(x) x = as_tensor_variable(x)
# FIXME # should be replaced with something that includes support for None and slices
# Note (9 Jul 2012): what does this 'FIXME' mean? Possibly that the # Note (9 Jul 2012): what does this 'FIXME' mean? Possibly that the
# current implementation must be generalized? Please specify. # current implementation must be generalized? Please specify.
if x.ndim == 2 and len(inputs) == 2: return gof.Apply(self,
ind1 = as_tensor_variable(inputs[0]) (x,) + tuple(map(as_tensor_variable, index)),
ind2 = as_tensor_variable(inputs[1]) [tensor(dtype = x.type.dtype,
if (not (ind1.type.dtype.startswith('int') or broadcastable = advanced_broadcastable(x, index) )])
ind1.type.dtype.startswith('uint'))):
raise TypeError(
'the indices into a matrix must be int or uint. It is ',
ind1.type.dtype)
if (not (ind2.type.dtype.startswith('int') or
ind2.type.dtype.startswith('uint'))):
raise TypeError(
'the indices into a matrix must be int or uint. It is ',
ind2.type.dtype)
if ind1.ndim == 1 and ind2.ndim == 1:
return gof.Apply(self,
(x, ind1, ind2),
[tensor(dtype=x.type.dtype,
broadcastable=[False])])
raise NotImplementedError(
'Advanced indexing of x (of dimension %i) with these argument'
' dimensions (%s) not supported yet'
% (x.ndim, ','.join(str(input.ndim) for input in inputs)))
raise NotImplementedError(
'Advanced indexing of x with arguments (%s) not supported yet'
% ','.join(str(input) for input in inputs))
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
...@@ -7326,24 +7349,10 @@ class AdvancedIncSubtensor(Op): ...@@ -7326,24 +7349,10 @@ class AdvancedIncSubtensor(Op):
x = as_tensor_variable(x) x = as_tensor_variable(x)
y = as_tensor_variable(y) y = as_tensor_variable(y)
if x.ndim == 2 and y.ndim == 1 and len(inputs) == 2: return gof.Apply(self,
ind1 = as_tensor_variable(inputs[0])
ind2 = as_tensor_variable(inputs[1])
if ind1.ndim == 1 and ind2.ndim == 1:
return gof.Apply(self,
(x, y) + inputs, (x, y) + inputs,
[tensor(dtype=x.type.dtype, [tensor(dtype=x.type.dtype,
broadcastable=x.type.broadcastable)]) broadcastable=x.type.broadcastable)])
raise NotImplementedError(
'Advanced indexing increment/set of x (of dimension %i) by y'
' (of dimension %i) with these argument dimensions (%s) not'
' supported yet'
% (x.ndim, y.ndim,
','.join(str(input.ndim) for input in inputs)))
raise NotImplementedError(
'Advanced indexing increment/set of x (of dim %i) by y (of dim %i)'
' with arguments (%s) not supported yet'
% (x.ndim, y.ndim, ','.join(str(input) for input in inputs)))
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
# TODO: 1. opt to make this in place 2. generalize as described in # TODO: 1. opt to make this in place 2. generalize as described in
...@@ -7351,14 +7360,13 @@ class AdvancedIncSubtensor(Op): ...@@ -7351,14 +7360,13 @@ class AdvancedIncSubtensor(Op):
out, = out_ out, = out_
if not self.inplace: if not self.inplace:
out[0] = inputs[0].copy()
else: a = inputs[0].copy()
raise NotImplementedError('In place computation is not' numpy.inplace_increment(a, tuple(inputs[2:]), inputs[1])
' implemented') out[0] = a
if self.set_instead_of_inc:
out[0][inputs[2:]] = inputs[1] out[0][inputs[2:]] = inputs[1]
else: else:
out[0][inputs[2:]] += inputs[1]
if (numpy.__version__ <= '1.6.1' and if (numpy.__version__ <= '1.6.1' and
out[0].size != numpy.uint32(out[0].size)): out[0].size != numpy.uint32(out[0].size)):
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment