Commit 4551ce72 authored by Razvan Pascanu

merge

@@ -2575,6 +2575,127 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
# View Operations
##########################
##########
# Helpful functions to deal with Subtensor and IncSubtensor
##########
def get_idx_list(inputs, idx_list):
    '''
    Reorder the runtime inputs of a subtensor according to its idx_list.

    ``idx_list`` is a template mixing constants, slices and ``gof.Type``
    placeholders; each placeholder stands for one of ``inputs[1:]``.
    Returns the template as a tuple with every placeholder replaced by
    the corresponding runtime input.
    '''
    # Placeholders consume the runtime inputs back-to-front, so keep a
    # reversed working copy we can pop() from.
    pending = list(inputs[1:])[::-1]
    if not pending:
        # No runtime index inputs: the idx_list is fully constant.
        return tuple(idx_list)

    def fill(item):
        # A gof.Type entry marks the slot for the next runtime input.
        if isinstance(item, gof.Type):
            return pending.pop()
        # Slices are filled component-wise (start/stop/step may each be
        # a placeholder themselves).
        if isinstance(item, slice):
            return slice(fill(item.start), fill(item.stop), fill(item.step))
        # Plain constants pass through unchanged.
        return item

    return tuple([fill(item) for item in idx_list])
def get_canonical_form_slice(theslice, length):
    '''
    Transform a slice ``[start:stop:step]`` (or a scalar index) into a
    canonical form that has no negative values.

    Canonical form is defined as:
        if step < 0:
            [if(stop<0,stop+length,stop):if(start<0,start+length,start):abs(step)][::-1]
        else:
            [if(start<0,start+length,start):if(stop<0,stop+length,stop):step]

    :param theslice: a ``slice`` object or a scalar index; components may
        be Python constants or symbolic variables.
    :param length: length of the dimension being indexed (constant or
        symbolic).
    :returns: a pair ``(canonical, reverse)`` where ``reverse`` is either
        ``None`` or ``slice(None, None, -1)`` depending on whether the
        result of applying the canonical form needs to be reversed.
    '''
    def extract_constant(x):
        '''
        Basically a call to tensor.get_constant_value.  The main
        difference is the behaviour on failure: while get_constant_value
        raises a TypeError, this function returns x itself, as a tensor
        (by removing a trailing scalar_from_tensor) if needed.
        '''
        try:
            x = get_constant_value(x)
        except Exception:
            # BUG FIX: this used to be a bare ``except:``, which also
            # swallowed KeyboardInterrupt and SystemExit.  We keep the
            # broad Exception catch (rather than just TypeError) to
            # preserve the original best-effort behaviour.
            pass
        if isinstance(x, scal.ScalarVariable):
            if x.owner and isinstance(x.owner.op, tensor.ScalarFromTensor):
                # Unwrap scalar_from_tensor(t) back to the tensor t.
                x = x.owner.inputs[0]
            else:
                x = tensor.tensor_from_scalar(x)
        return x
    if isinstance(theslice, slice):
        start = extract_constant(theslice.start)
        stop = extract_constant(theslice.stop)
        step = extract_constant(theslice.step)
        # Normalize trivial bounds to None; __getitem__ passes
        # sys.maxint for a missing stop under Python 2.
        if stop == sys.maxint:
            stop = None
        if start == 0:
            start = None
        if step == 1:
            step = None
        # NOTE(review): ``type(step) is int`` rejects numpy integer and
        # bool subclasses; a negative numpy-int step would fall into the
        # positive-step branch below — confirm whether that is intended.
        if type(step) is int and step < 0:
            # Constant negative step: express as a positive-step slice on
            # swapped, wrapped bounds, plus a final [::-1] reversal.
            if stop is not None:
                nw_start = switch(lt(stop, 0), stop + length, stop)
                # safety guard: clamp negatives to 0 (this does make the
                # graph more annoying, but keeps the result valid)
                nw_start = switch(lt(nw_start, 0), 0, nw_start)
            else:
                nw_start = None
            if start is not None:
                nw_stop = switch(lt(start, 0), start + length, start)
                # safety guard: clamp negatives to 0
                nw_stop = switch(lt(nw_stop, 0), 0, nw_stop)
            else:
                nw_stop = None
            nw_step = abs(step)
            nw_slice = slice(nw_start, nw_stop, nw_step)
            return nw_slice, slice(None, None, -1)
        else:
            # Positive (or symbolic) step: only wrap negative bounds.
            if start is not None:
                nw_start = switch(lt(start, 0), start + length, start)
                # safety guard: clamp negatives to 0
                nw_start = switch(lt(nw_start, 0), 0, nw_start)
            else:
                nw_start = None
            if stop is not None:
                nw_stop = switch(lt(stop, 0), stop + length, stop)
                # safety guard: clamp negatives to 0
                nw_stop = switch(lt(nw_stop, 0), 0, nw_stop)
            else:
                nw_stop = None
            nw_step = step
            nw_slice = slice(nw_start, nw_stop, nw_step)
            return nw_slice, None
    else:
        # Scalar index: just wrap a negative index by the length.
        value = extract_constant(theslice)
        value = switch(lt(value, 0), value + length, value)
        return value, None
def transpose(x, **kwargs):
    dims = range(x.ndim-1, -1, -1)
    return DimShuffle(x.broadcastable, dims, inplace=True)(tensor_copy(x))
...@@ -2718,30 +2839,12 @@ class Subtensor(Op): ...@@ -2718,30 +2839,12 @@ class Subtensor(Op):
out[0] = numpy.asarray(x.__getitem__(self.perform_cache_cdata)) out[0] = numpy.asarray(x.__getitem__(self.perform_cache_cdata))
return return
indices = list(reversed(inputs[1:])) cdata = get_idx_list(inputs, self.idx_list)
if len(cdata) == 1:
# The subtensor (or idx_list) does not depend on the inputs. cdata = cdata[0]
# (first call caches cdata here) # (first call caches cdata here)
if len(indices) == 0: if len(inputs[1:]) == 0:
cdata = tuple(self.idx_list)
if len(cdata) == 1:
cdata = cdata[0]
self.perform_cache_cdata = cdata self.perform_cache_cdata = cdata
# General case
else:
def convert(entry):
if isinstance(entry, gof.Type):
return indices.pop()
elif isinstance(entry, slice):
return slice(convert(entry.start),
convert(entry.stop),
convert(entry.step))
else:
return entry
cdata = tuple(map(convert, self.idx_list))
if len(cdata) == 1:
cdata = cdata[0]
out[0] = numpy.asarray(x.__getitem__(cdata)) out[0] = numpy.asarray(x.__getitem__(cdata))
...@@ -2749,7 +2852,8 @@ class Subtensor(Op): ...@@ -2749,7 +2852,8 @@ class Subtensor(Op):
xshp = shapes[0] xshp = shapes[0]
assert len(xshp) == node.inputs[0].ndim assert len(xshp) == node.inputs[0].ndim
outshp = [] outshp = []
padded = self.idx_list + [slice(None, None, None)] * (len(xshp) - len(self.idx_list)) actual_idx_list = list(get_idx_list(node.inputs, self.idx_list))
padded = actual_idx_list + [slice(None, None, None)] * (len(xshp) - len(self.idx_list))
i = 0 i = 0
shape_i = node.env.shape_feature.shape_i shape_i = node.env.shape_feature.shape_i
for idx, xl in zip(padded, xshp): for idx, xl in zip(padded, xshp):
...@@ -2758,11 +2862,23 @@ class Subtensor(Op): ...@@ -2758,11 +2862,23 @@ class Subtensor(Op):
# the shape will be xl # the shape will be xl
if (idx.start is None or idx.start == 0)\ if (idx.start is None or idx.start == 0)\
and (idx.stop is None or idx.stop == sys.maxint)\ and (idx.stop is None or idx.stop == sys.maxint)\
and (idx.step is None or idx.step == 1): and (idx.step is None or abs(idx.step) == 1):
outshp.append(xl) outshp.append(xl)
else: else:
# Not implemented yet cnf = get_canonical_form_slice(idx, xl)
outshp.append(shape_i(i)(node.outputs[0])) if cnf[0].stop not in [None, sys.maxint]:
length = cnf[0].stop
else:
length = xl
if cnf[0].start not in [None,0]:
length = length - cnf[0].start
length = switch(lt(length,0), 0, length)
if cnf[0].step not in [None, 1]:
# any more elegant way of doing this??
length = cast(
ceil(length / cast(cnf[0].step,'float32')),'int64')
outshp.append(length)
i += 1 i += 1
else: else:
# That dimension is dropped # That dimension is dropped
......
...@@ -1689,6 +1689,18 @@ class T_subtensor(unittest.TestCase): ...@@ -1689,6 +1689,18 @@ class T_subtensor(unittest.TestCase):
self.failUnless(isinstance(topo_[0].op, self.adv_sub1)) self.failUnless(isinstance(topo_[0].op, self.adv_sub1))
self.assertRaises(IndexError, f) self.assertRaises(IndexError, f)
def test_shape_i(self):
    '''
    The shape of a sliced tensor should be inferred symbolically: the
    compiled shape function must match numpy's result and must not need
    to execute the Subtensor op itself.
    '''
    data = self.shared(numpy.zeros((50, 50, 50, 50), dtype='int32'))
    for slices in [(slice(2, 10, 2), slice(None, None, None), slice(None, None, -1)),
                   (slice(-5, 10, 1), slice(10, 2, -1), slice(4, None, None)),
                   (slice(3, -10, 1), slice(10, 15, 8))]:
        sliced_data = data[slices]
        f = function([], sliced_data.shape)
        assert numpy.all(f() == data.get_value()[slices].shape)
        # BUG FIX: the original wrote
        #   theano.tensor.Subtensor not in [x.op for x in ...]
        # which compares the Subtensor *class* against op *instances*,
        # so the membership test could never match and the assertion
        # was vacuous.  Check with isinstance instead.
        assert not any(isinstance(node.op, theano.tensor.Subtensor)
                       for node in f.maker.env.toposort())
def grad_list_(self, idxs, data): def grad_list_(self, idxs, data):
n = self.shared(data) n = self.shared(data)
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment