提交 902d822b authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixing the canonical form of a slice following the code provided by David

上级 252c57ca
...@@ -2611,17 +2611,11 @@ def get_idx_list(inputs, idx_list): ...@@ -2611,17 +2611,11 @@ def get_idx_list(inputs, idx_list):
def get_canonical_form_slice(theslice, length): def get_canonical_form_slice(theslice, length):
''' '''
Given a slice [start:stop:step] transform it into a canonical form Given a slice [start:stop:step] transform it into a canonical form
that has no negative values. Canonical form is defined as : that respects the conventions imposed by python and numpy.
if step < 0 :
[if(stop<0,stop+length,stop):if(start<0,start+length,start):abs(step)][::-1]
else:
[if(start<0,start+length,start):if(stop<0,stop+length,stop):step]
the function will return the canonical form and either None or [::-1]
depending if the result of the canonical form needs to be reversed
We currently don't canonicalize variable step. We we still return an In a canonical form a slice is represented by a canonical form slice,
equivalent slice, so no bug introduced. in which the start <= stop and step >0 and a flag which says if the
resulting set of numbers needs to be reversed or not.
''' '''
def extract_constant(x): def extract_constant(x):
...@@ -2629,14 +2623,15 @@ def get_canonical_form_slice(theslice, length): ...@@ -2629,14 +2623,15 @@ def get_canonical_form_slice(theslice, length):
This function is basically a call to tensor.get_constant_value. The This function is basically a call to tensor.get_constant_value. The
main difference is the behaviour in case of failure. While main difference is the behaviour in case of failure. While
get_constant_value raises an TypeError, this function returns x, get_constant_value raises an TypeError, this function returns x,
as a tensor ( by removing the last scalar_from_tensor ) if needed. as a tensor ( by removing the last scalar_from_tensor ) if needed
or None if that is the value of x.
''' '''
try: try:
x = get_constant_value(x) x = get_constant_value(x)
except: except:
pass pass
if isinstance(x, scal.ScalarVariable): if isinstance(x, scal.ScalarVariable):
if x.owner and isinstance(x.owner.op, tensor.ScalarFromTensor): if x.owner and isinstance(x.owner.op, ScalarFromTensor):
x = x.owner.inputs[0] x = x.owner.inputs[0]
else: else:
x = tensor.tensor_from_scalar(x) x = tensor.tensor_from_scalar(x)
...@@ -2647,43 +2642,39 @@ def get_canonical_form_slice(theslice, length): ...@@ -2647,43 +2642,39 @@ def get_canonical_form_slice(theslice, length):
start = extract_constant(theslice.start) start = extract_constant(theslice.start)
stop = extract_constant(theslice.stop) stop = extract_constant(theslice.stop)
step = extract_constant(theslice.step) step = extract_constant(theslice.step)
# try to escape cases like sys.maxint for stop condition if step is None:
if stop == sys.maxint : step = 1
stop = None
if start == 0 : defstart = switch(lt(step,0), length-1, 0)
start = None defstop = switch(lt(step,0), -1, length )
if step == 1: if start is None:
step = None start = defstart
if start is not None:
nw_start = switch(lt(start,0),start+length, start)
# safety guards .. this will make the graph so much more
# annoying :(
nw_start = switch(lt(nw_start,0), 0, nw_start)
else: else:
nw_start = None start = switch(lt(start,0), start + length, start)
if stop is not None: start = switch(lt(start,0), switch(lt(step,0), -1, 0), start)
nw_stop = switch(lt(stop,0),stop + length,stop) start = switch(ge(start,length)
# safety guards .. this will make the graph so much more , switch(lt(step,0),length-1,length)
# annoying :( , start)
nw_stop = switch(lt(nw_stop,0), 0, nw_stop) if stop in [None, sys.maxint]:
stop = defstop
else: else:
nw_stop = None stop = switch(lt(stop,0), stop + length, stop)
nw_step = step stop = switch(lt(stop,0), -1, stop)
stop = switch(ge(stop,length), length,stop)
# When the step is constant and negative nw_stop = switch(lt(step,0), start+1, stop )
# we make it positive. nw_start = switch(lt(step,0), stop +1, start)
# We do not for now canonicalize variable step
if type(step) is int and step < 0 :
nw_start, nw_stop = nw_stop, nw_start
nw_step = abs(step)
nw_slice = slice(nw_start,nw_stop, nw_step)
return nw_slice, None
nw_step = abs(step)
if step != 1:
reverse = sgn(step)
return slice(nw_start, nw_stop, nw_step), reverse
else:
return slice(nw_start, nw_stop, nw_step), None
else: else:
value = extract_constant(theslice) value = extract_constant(theslice)
value = switch(lt(value,0), value+length, value) value = switch(lt(value,0), value+length, value)
return value, None return value, None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论