提交 902d822b authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixing the canonical form of a slice following the code provided by David

上级 252c57ca
......@@ -2611,17 +2611,11 @@ def get_idx_list(inputs, idx_list):
def get_canonical_form_slice(theslice, length):
'''
Given a slice [start:stop:step] transform it into a canonical form
that has no negative values. Canonical form is defined as :
if step < 0 :
[if(stop<0,stop+length,stop):if(start<0,start+length,start):abs(step)][::-1]
else:
[if(start<0,start+length,start):if(stop<0,stop+length,stop):step]
the function will return the canonical form and either None or [::-1]
depending if the result of the canonical form needs to be reversed
that respects the conventions imposed by python and numpy.
We currently don't canonicalize variable step. We we still return an
equivalent slice, so no bug introduced.
In a canonical form a slice is represented by a canonical form slice,
in which the start <= stop and step >0 and a flag which says if the
resulting set of numbers needs to be reversed or not.
'''
def extract_constant(x):
......@@ -2629,14 +2623,15 @@ def get_canonical_form_slice(theslice, length):
This function is basically a call to tensor.get_constant_value. The
main difference is the behaviour in case of failure. While
get_constant_value raises an TypeError, this function returns x,
as a tensor ( by removing the last scalar_from_tensor ) if needed.
as a tensor ( by removing the last scalar_from_tensor ) if needed
or None if that is the value of x.
'''
try:
x = get_constant_value(x)
except:
pass
if isinstance(x, scal.ScalarVariable):
if x.owner and isinstance(x.owner.op, tensor.ScalarFromTensor):
if x.owner and isinstance(x.owner.op, ScalarFromTensor):
x = x.owner.inputs[0]
else:
x = tensor.tensor_from_scalar(x)
......@@ -2647,43 +2642,39 @@ def get_canonical_form_slice(theslice, length):
start = extract_constant(theslice.start)
stop = extract_constant(theslice.stop)
step = extract_constant(theslice.step)
# try to escape cases like sys.maxint for stop condition
if stop == sys.maxint :
stop = None
if start == 0 :
start = None
if step == 1:
step = None
if start is not None:
nw_start = switch(lt(start,0),start+length, start)
# safety guards .. this will make the graph so much more
# annoying :(
nw_start = switch(lt(nw_start,0), 0, nw_start)
if step is None:
step = 1
defstart = switch(lt(step,0), length-1, 0)
defstop = switch(lt(step,0), -1, length )
if start is None:
start = defstart
else:
nw_start = None
if stop is not None:
nw_stop = switch(lt(stop,0),stop + length,stop)
# safety guards .. this will make the graph so much more
# annoying :(
nw_stop = switch(lt(nw_stop,0), 0, nw_stop)
start = switch(lt(start,0), start + length, start)
start = switch(lt(start,0), switch(lt(step,0), -1, 0), start)
start = switch(ge(start,length)
, switch(lt(step,0),length-1,length)
, start)
if stop in [None, sys.maxint]:
stop = defstop
else:
nw_stop = None
nw_step = step
stop = switch(lt(stop,0), stop + length, stop)
stop = switch(lt(stop,0), -1, stop)
stop = switch(ge(stop,length), length,stop)
# When the step is constant and negative
# we make it positive.
# We do not for now canonicalize variable step
if type(step) is int and step < 0 :
nw_start, nw_stop = nw_stop, nw_start
nw_step = abs(step)
nw_slice = slice(nw_start,nw_stop, nw_step)
return nw_slice, None
nw_stop = switch(lt(step,0), start+1, stop )
nw_start = switch(lt(step,0), stop +1, start)
nw_step = abs(step)
if step != 1:
reverse = sgn(step)
return slice(nw_start, nw_stop, nw_step), reverse
else:
return slice(nw_start, nw_stop, nw_step), None
else:
value = extract_constant(theslice)
value = switch(lt(value,0), value+length, value)
return value, None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论