提交 9c07bd65 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added two function to help deal with subtensors. One function defines the

slices given the idx_list and the list of node inputs. The other function transforms a slice into a canonical form ( where all entries are positives), which makes dealing with slices much easier. I've also refractored the perform of the Subtensor to use the first function ( since is mostly copy paste from that code anyway)
上级 9b238d6b
...@@ -2575,6 +2575,127 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right')) ...@@ -2575,6 +2575,127 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
# View Operations # View Operations
########################## ##########################
##########
# Helpful functions to deal with Subtensor and IncSubtensor
##########
def get_idx_list(inputs, idx_list):
'''
Given a list of inputs to the subtensor and its idx_list reorders
the inputs according to the idx list to get the right values
'''
# The subtensor (or idx_list) does not depend on the inputs.
indices = list(reversed(list(inputs[1:])))
if len(indices) == 0:
return tuple(idx_list)
# General case
def convert(entry):
if isinstance(entry, gof.Type):
return indices.pop()
elif isinstance(entry, slice):
return slice(convert(entry.start),
convert(entry.stop),
convert(entry.step))
else:
return entry
cdata = tuple(map(convert, idx_list))
return cdata
def get_canonical_form_slice(theslice, length):
'''
Given a slice [start:stop:step] transform it into a canonical form
that has no negative values. Canonical form is defined as :
if step < 0 :
[if(stop<0,stop+length,stop):if(start<0,start+length,start):abs(step)][::-1]
else:
[if(start<0,start+length,start):if(stop<0,stop+length,stop):step]
the function will return the canonical form and either None or [::-1]
depending if the result of the canonical form needs to be reversed
'''
def extract_constant(x):
'''
This function is basically a call to tensor.get_constant_value. The
main difference is the behaviour in case of failure. While
get_constant_value raises an TypeError, this function returns x,
as a tensor ( by removing the last scalar_from_tensor ) if needed.
'''
try:
x = get_constant_value(x)
except:
pass
if isinstance(x, scal.ScalarVariable):
if x.owner and isinstance(x.owner.op, tensor.ScalarFromTensor):
x = x.owner.inputs[0]
else:
x = tensor.tensor_from_scalar(x)
return x
if isinstance(theslice,slice):
start = extract_constant(theslice.start)
stop = extract_constant(theslice.stop)
step = extract_constant(theslice.step)
# try to escape cases like sys.maxint for stop condition
if stop == sys.maxint :
stop = None
if start == 0 :
start = None
if step == 1:
step = None
if type(step) is int and step < 0 :
if stop is not None:
nw_start = switch(lt(stop,0), stop+length, stop)
# safety guards .. this will make the graph so much more
# annoying :(
nw_start = switch(lt(nw_start,0), 0, nw_start)
else:
nw_start = None
if start is not None:
nw_stop = switch(lt(start,0), start+length, start)
# safety guards .. this will make the graph so much more
# annoying :(
nw_stop = switch(lt(nw_stop, 0), 0, nw_stop)
else:
nw_stop = None
nw_step = abs(step)
nw_slice = slice(nw_start,nw_stop, nw_step)
return nw_slice, slice(None,None,-1)
else:
if start is not None:
nw_start = switch(lt(start,0),start+length, start)
# safety guards .. this will make the graph so much more
# annoying :(
nw_start = switch(lt(nw_start,0), 0, nw_start)
else:
nw_start = None
if stop is not None:
nw_stop = switch(lt(stop,0),stop + length,stop)
# safety guards .. this will make the graph so much more
# annoying :(
nw_stop = switch(lt(nw_stop,0), 0, nw_stop)
else:
nw_stop = None
nw_step = step
nw_slice = slice(nw_start,nw_stop, nw_step)
return nw_slice, None
else:
value = extract_constant(theslice)
value = switch(lt(value,0), value+length, value)
return value, None
def transpose(x, **kwargs): def transpose(x, **kwargs):
dims = range(x.ndim-1, -1, -1) dims = range(x.ndim-1, -1, -1)
return DimShuffle(x.broadcastable, dims, inplace=True)(tensor_copy(x)) return DimShuffle(x.broadcastable, dims, inplace=True)(tensor_copy(x))
...@@ -2718,30 +2839,12 @@ class Subtensor(Op): ...@@ -2718,30 +2839,12 @@ class Subtensor(Op):
out[0] = numpy.asarray(x.__getitem__(self.perform_cache_cdata)) out[0] = numpy.asarray(x.__getitem__(self.perform_cache_cdata))
return return
indices = list(reversed(inputs[1:])) cdata = get_idx_list(inputs, self.idx_list)
if len(cdata) == 1:
# The subtensor (or idx_list) does not depend on the inputs. cdata = cdata[0]
# (first call caches cdata here) # (first call caches cdata here)
if len(indices) == 0: if len(inputs[1:]) == 0:
cdata = tuple(self.idx_list)
if len(cdata) == 1:
cdata = cdata[0]
self.perform_cache_cdata = cdata self.perform_cache_cdata = cdata
# General case
else:
def convert(entry):
if isinstance(entry, gof.Type):
return indices.pop()
elif isinstance(entry, slice):
return slice(convert(entry.start),
convert(entry.stop),
convert(entry.step))
else:
return entry
cdata = tuple(map(convert, self.idx_list))
if len(cdata) == 1:
cdata = cdata[0]
out[0] = numpy.asarray(x.__getitem__(cdata)) out[0] = numpy.asarray(x.__getitem__(cdata))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论