提交 a003a1c4 authored 作者: John Salvatier's avatar John Salvatier

trying out slice objects

上级 76798960
......@@ -7179,15 +7179,82 @@ class AdvancedIncSubtensor1(Op):
advanced_inc_subtensor1 = AdvancedIncSubtensor1()
def as_index_variable(idx):
if idx is None:
return NoneConst
if isinstance(idx, slice):
return make_slice(idx)
idx = as_tensor_variable(idx)
if idx.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers')
return idx
def as_int_none_variable(x):
if x is None:
return NoneConst
x = as_tensor_variable(x, ndim = 0)
if x.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers')
return x
class MakeSlice(Op):
def make_node(self, slc):
return Apply(self,
map(as_int_none_variable,[slc.start, slc.stop, slc.step]),
[Slice()])
def perform(self, node, inp, out_):
out, = out_
out[0] = slice(*inp)
def __str__(self):
return self.__class__.__name__
make_slice = MakeSlice
class Slice(gof.Type):
def filter(self, x, strict=False, allow_downcast=None):
if isinstance(x, slice):
return x
else:
raise TypeError('Expected a slice!')
def __str__(self):
return "slice"
class NoneTypeT(gof.Type):
def filter(self, x, strict=False, allow_downcast=None):
if x is None:
return x
else:
raise TypeError('Expected None!')
def __str__(self):
return "None"
NoneConst = Constant(NoneTypeT(), None, name = 'None')
def adv_broadcastable(a, idx):
def replace_slice(v):
if isinstance(v, slice):
if isinstance(v, gof.Apply):
if len(v.outputs) != 1:
raise ValueError(
"It is ambiguous which output of a multi-output Op has"
" to be fetched.", v)
else:
v = v.outputs[0]
if v is NoneConst:
return None
if isinstance(v.type, Slice):
return slice(None,None)
try :
return numpy.zeros( (1,)* as_tensor_variable(v).ndim, int)
except ValueError:
return v
return numpy.zeros( (1,)* v.ndim, int)
newidx = tuple(map(replace_slice, idx))
......@@ -7214,17 +7281,11 @@ class AdvancedSubtensor(Op):
def make_node(self, x, *index):
x = as_tensor_variable(x)
def as_index_variable(a):
try:
return as_tensor_variable(a)
except :
return a
idxvars = tuple(map(as_index_variable, index))
index = tuple(map(as_index_variable, index))
return gof.Apply(self,
(x,) + idxvars,
(x,) + index,
[tensor(dtype = x.type.dtype,
broadcastable = adv_broadcastable(x, index) )])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论