Commit 76798960 authored by jsalvatier, committed by John Salvatier

corrected some errors

Parent 9d94cece
@@ -1759,18 +1759,10 @@ class _tensor_py_operators:
axis = i
if advanced:
if (axis is not None
and numpy.all(a == slice(None) for a in args[:axis])
and numpy.all(a == slice(None) for a in args[axis + 1:])
and isinstance(args[axis], (
numpy.ndarray,
list,
TensorVariable,
TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))):
return self.take(arg, axis)
else:
return AdvancedSubtensor()(self, *args)
return AdvancedSubtensor()(self, *args)
else:
if numpy.newaxis in args:
# None (aka np.newaxis) in numpy indexing means to add a
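For reference, a minimal plain-NumPy sketch (ordinary ndarrays, not Theano variables) of the equivalence that the removed take() branch relied on: when every index except one is a full slice, advanced indexing along that axis selects the same entries as take() along it.

    import numpy as np

    x = np.arange(12).reshape(3, 4)
    idx = [2, 0]
    # Full slices everywhere except axis 1 ...
    by_index = x[:, idx]
    # ... pick the same entries as take() along that axis.
    by_take = x.take(idx, axis=1)
    assert np.array_equal(by_index, by_take)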
@@ -7186,51 +7178,22 @@ class AdvancedIncSubtensor1(Op):
return [gx, gy] + [DisconnectedType()()] * len(idx_list)
advanced_inc_subtensor1 = AdvancedIncSubtensor1()
from itertools import groupby, chain
def simpleindex(a):
try:
return as_tensor_variable(a).ndim == 0
except:
return True
def simple_broadcastable(a, idx):
def adv_broadcastable(a, idx):
def replace_slice(v):
if isinstance(v, slice):
return slice(None,None)
if simpleindex(v):
return 0
return v
try :
return numpy.zeros( (1,)* as_tensor_variable(v).ndim, int)
except ValueError:
return v
newidx = tuple(map(replace_slice, idx))
fakeshape = [bc + 1 for bc in a.broadcastable]
retshape = numpy.empty(fakeshape)[newidx].shape
return tuple([dim == 1 for dim in retshape])
from __builtin__ import sum as concat
def concat(ls):
r = []
map(r.extend, ls)
return r
def advanced_broadcastable(a, idx):
chunks = list(groupby(idx, simpleindex))
chunks = [(s, list(c)) for s,c in chunks]
if len(chunks) > 3:
chunks = [concat(c for s, c in chunks if s), concat(c for s, c in chunks if not s)]
def getbroad((simple, c)):
if simple:
return simple_broadcastable(a, c)
else:
return as_tensor_variable(c[0]).broadcastable
return concat(map(getbroad, chunks))
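A standalone NumPy sketch of the dummy-array trick behind adv_broadcastable above, assuming broadcastable axes map to length-1 dummy dimensions and the indices are concrete slices, lists or arrays (the helper name output_pattern is ours for illustration; unlike replace_slice, this sketch keeps each index's full shape rather than only its number of dimensions): run the index expression on a tiny dummy array and read the pattern off the resulting shape instead of re-deriving NumPy's advanced-indexing shape rules by hand.

    import numpy as np

    def output_pattern(broadcastable, idx):
        # Dummy input: length 1 on broadcastable axes, length 2 elsewhere.
        fake_shape = tuple(1 if bc else 2 for bc in broadcastable)

        # Neutralize concrete index arrays: same shape, but all zeros, so the
        # dummy lookup stays in bounds; slices pass through unchanged.
        def neutralize(v):
            return v if isinstance(v, slice) else np.zeros(np.shape(v), int)

        out_shape = np.empty(fake_shape)[tuple(neutralize(v) for v in idx)].shape
        return tuple(dim == 1 for dim in out_shape)

    # A tensor with pattern (False, True), indexed as x[[1, 0], :]:
    print(output_pattern((False, True), ([1, 0], slice(None))))  # (False, True)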
class AdvancedSubtensor(Op):
"""Return a subtensor copy, using advanced indexing.
@@ -7250,13 +7213,20 @@ class AdvancedSubtensor(Op):
def make_node(self, x, *index):
x = as_tensor_variable(x)
# should be replaced with something that includes support for None and slices
# Note (9 Jul 2012): what does this 'FIXME' mean? Possibly that the
# current implementation must be generalized? Please specify.
def as_index_variable(a):
try:
return as_tensor_variable(a)
except :
return a
idxvars = tuple(map(as_index_variable, index))
return gof.Apply(self,
(x,) + tuple(map(as_tensor_variable, index)),
(x,) + idxvars,
[tensor(dtype = x.type.dtype,
broadcastable = advanced_broadcastable(x, index) )])
broadcastable = adv_broadcastable(x, index) )])
def R_op(self, inputs, eval_points):
......
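A minimal sketch of the try/except conversion used by as_index_variable in make_node above (the committed version uses a bare except; Exception is used here only for the sketch): indices Theano can turn into tensors are converted, while anything it cannot, such as slice objects, is passed through unchanged so AdvancedSubtensor still sees it.

    from theano.tensor import as_tensor_variable

    def as_index_variable(a):
        try:
            return as_tensor_variable(a)   # lists, ndarrays, scalars, Variables
        except Exception:
            return a                       # e.g. slice objects pass through unchanged

    idx = (slice(None), [0, 2])
    idxvars = tuple(map(as_index_variable, idx))
    # idxvars[0] is still a plain slice; idxvars[1] is now a Theano constant tensor.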