提交 0e35cc21 authored 作者: lamblin's avatar lamblin

Merge pull request #1299 from delallea/canonical_slice

Simpler graphs for canonical slices
...@@ -510,7 +510,7 @@ def get_scalar_constant_value(v): ...@@ -510,7 +510,7 @@ def get_scalar_constant_value(v):
if isinstance(v, (numpy.integer, int, float)): if isinstance(v, (numpy.integer, int, float)):
return numpy.asarray(v) return numpy.asarray(v)
def numpy_scalar(n): def numpy_scalar(data):
""" Return a scalar stored in a numpy ndarray, or raise """ Return a scalar stored in a numpy ndarray, or raise
NotScalarConstantError if the numpy ndarray is not a scalar NotScalarConstantError if the numpy ndarray is not a scalar
""" """
...@@ -526,7 +526,7 @@ def get_scalar_constant_value(v): ...@@ -526,7 +526,7 @@ def get_scalar_constant_value(v):
except Exception: except Exception:
raise NotScalarConstantError( raise NotScalarConstantError(
'v.data is non-numeric, non-scalar, or has more than one' 'v.data is non-numeric, non-scalar, or has more than one'
' unique value', n) ' unique value', data)
if isinstance(v, numpy.ndarray): if isinstance(v, numpy.ndarray):
return numpy_scalar(v) return numpy_scalar(v)
...@@ -4250,27 +4250,114 @@ def get_canonical_form_slice(theslice, length): ...@@ -4250,27 +4250,114 @@ def get_canonical_form_slice(theslice, length):
that respects the conventions imposed by python and numpy. that respects the conventions imposed by python and numpy.
In a canonical form a slice is represented by a canonical form slice, In a canonical form a slice is represented by a canonical form slice,
in which the start <= stop and step >0 and a flag which says if the in which 0 <= start <= stop <= length and step > 0, and a flag which says
resulting set of numbers needs to be reversed or not. if the resulting set of numbers needs to be reversed or not.
''' '''
if isinstance(theslice, slice): if isinstance(theslice, slice):
start = extract_constant(theslice.start) def analyze(x):
stop = extract_constant(theslice.stop) try:
step = extract_constant(theslice.step) x_constant = get_scalar_constant_value(x)
is_constant = True
except NotScalarConstantError:
x_constant = extract_constant(x)
is_constant = False
return x_constant, is_constant
start, is_start_constant = analyze(theslice.start)
stop, is_stop_constant = analyze(theslice.stop)
step, is_step_constant = analyze(theslice.step)
length, is_length_constant = analyze(length)
if step is None: if step is None:
step = 1 step = 1
defstart = switch(lt(step, 0), (length - 1), 0) # First handle the easier and common case where `step` is 1 and
defstop = switch(lt(step, 0), -1, length) # either `start` or `stop` is a range boundary. More specializations
# could be added later. This makes the resulting graph smaller than
# in the generic case below.
if step == 1:
is_start_0 = (
start in [None, 0] or
(is_start_constant and is_length_constant and
start < 0 and start + length <= 0))
is_stop_length = (
stop in [None, length, maxsize] or
(is_stop_constant and is_length_constant and
stop >= length))
if is_start_0:
# 0:stop:1
if is_stop_length:
# Full slice.
return slice(0, length, 1), 1
if is_stop_constant and stop >= 0:
return (slice(0, switch(lt(stop, length), stop, length),
1), 1)
stop_plus_len = stop + length
stop = switch(
lt(stop, 0),
# stop < 0
switch(
lt(stop_plus_len, 0),
# stop + len < 0
0,
# stop + len >= 0
stop_plus_len),
# stop >= 0: use min(stop, length)
switch(lt(stop, length), stop, length))
return slice(0, stop, 1), 1
elif is_stop_length:
# start:length:1
if is_start_constant and start >= 0:
return slice(switch(lt(start, length), start, length),
length, 1), 1
start_plus_len = start + length
start = switch(
lt(start, 0),
# start < 0
switch(
lt(start_plus_len, 0),
# start + len < 0
0,
# start + len >= 0
start_plus_len),
# start >= 0: use min(start, length)
switch(lt(start, length), start, length))
return slice(start, length, 1), 1
# This is the generic case.
if is_step_constant:
# When we know the sign of `step`, the graph can be made simpler.
assert step != 0
if step > 0:
def switch_neg_step(a, b):
return b
abs_step = step
sgn_step = 1
else:
def switch_neg_step(a, b):
return a
abs_step = -step
sgn_step = -1
else:
is_step_neg = lt(step, 0)
def switch_neg_step(a, b):
return switch(is_step_neg, a, b)
abs_step = abs(step)
sgn_step = sgn(step)
defstart = switch_neg_step(length - 1, 0)
defstop = switch_neg_step(-1, length)
if start is None: if start is None:
start = defstart start = defstart
else: else:
start = switch(lt(start, 0), start + length, start) start = switch(lt(start, 0), start + length, start)
start = switch(lt(start, 0), switch(lt(step, 0), -1, 0), start) start = switch(lt(start, 0), switch_neg_step(-1, 0), start)
start = switch(ge(start, length), start = switch(ge(start, length),
switch(lt(step, 0), (length - 1), length), switch_neg_step(length - 1, length),
start) start)
if stop in [None, maxsize]: if stop in [None, maxsize]:
# The special "maxsize" case is probably not needed here, # The special "maxsize" case is probably not needed here,
...@@ -4282,18 +4369,20 @@ def get_canonical_form_slice(theslice, length): ...@@ -4282,18 +4369,20 @@ def get_canonical_form_slice(theslice, length):
stop = switch(lt(stop, 0), -1, stop) stop = switch(lt(stop, 0), -1, stop)
stop = switch(ge(stop, length), length, stop) stop = switch(ge(stop, length), length, stop)
nw_stop = switch(lt(step, 0), (start + 1), stop) nw_stop = switch_neg_step(start + 1, stop)
slice_len = (start - stop - 1) // abs(step) + 1 slice_len = (start - stop - 1) // abs_step + 1
slice_len = switch(lt(slice_len, 0), 0, slice_len) slice_len = switch(lt(slice_len, 0), 0, slice_len)
neg_start = nw_stop - (slice_len - 1) * abs(step) - 1 neg_start = nw_stop - (slice_len - 1) * abs_step - 1
neg_start = switch(lt(neg_start, 0), (nw_stop - 1), neg_start) neg_start = switch(lt(neg_start, 0), (nw_stop - 1), neg_start)
nw_start = switch(lt(step, 0), neg_start, start) nw_start = switch_neg_step(neg_start, start)
nw_start = switch(lt(nw_start, 0), 0, nw_start) nw_start = switch(lt(nw_start, 0), 0, nw_start)
nw_stop = switch(lt(nw_stop, 0), 0, nw_stop) nw_stop = switch(lt(nw_stop, 0), 0, nw_stop)
# Ensure start <= stop.
nw_start = switch(lt(nw_start, nw_stop), nw_start, nw_stop)
nw_step = abs(step) nw_step = abs_step
if step != 1: if step != 1:
reverse = sgn(step) reverse = sgn_step
return slice(nw_start, nw_stop, nw_step), reverse return slice(nw_start, nw_stop, nw_step), reverse
else: else:
return slice(nw_start, nw_stop, nw_step), 1 return slice(nw_start, nw_stop, nw_step), 1
...@@ -4554,10 +4643,11 @@ class Subtensor(Op): ...@@ -4554,10 +4643,11 @@ class Subtensor(Op):
and (idx.step is None or idx.step == 1)): and (idx.step is None or idx.step == 1)):
outshp.append(xl) outshp.append(xl)
else: else:
cnf = get_canonical_form_slice(idx, xl) cnf = get_canonical_form_slice(idx, xl)[0]
length = ((cnf[0].stop - cnf[0].start - 1) // cnf[0].step if cnf.step == 1:
+ 1) length = cnf.stop - cnf.start
length = switch(lt(length, 0), 0, length) else:
length = (cnf.stop - cnf.start - 1) // cnf.step + 1
outshp.append(length) outshp.append(length)
i += 1 i += 1
else: else:
...@@ -7031,7 +7121,8 @@ class AdvancedSubtensor1(Op): ...@@ -7031,7 +7121,8 @@ class AdvancedSubtensor1(Op):
// if all values fit. // if all values fit.
if (!PyArray_CanCastSafely(i_type, NPY_INTP)) { if (!PyArray_CanCastSafely(i_type, NPY_INTP)) {
npy_int64 min_val, max_val; npy_int64 min_val, max_val;
PyObject* py_min_val = PyArray_Min(%(i_name)s, NPY_MAXDIMS, NULL); PyObject* py_min_val = PyArray_Min(%(i_name)s, NPY_MAXDIMS,
NULL);
if (py_min_val == NULL) { if (py_min_val == NULL) {
%(fail)s; %(fail)s;
} }
...@@ -7040,7 +7131,8 @@ class AdvancedSubtensor1(Op): ...@@ -7040,7 +7131,8 @@ class AdvancedSubtensor1(Op):
if (min_val == -1 && PyErr_Occurred()) { if (min_val == -1 && PyErr_Occurred()) {
%(fail)s; %(fail)s;
} }
PyObject* py_max_val = PyArray_Max(%(i_name)s, NPY_MAXDIMS, NULL); PyObject* py_max_val = PyArray_Max(%(i_name)s, NPY_MAXDIMS,
NULL);
if (py_max_val == NULL) { if (py_max_val == NULL) {
%(fail)s; %(fail)s;
} }
...@@ -7050,7 +7142,8 @@ class AdvancedSubtensor1(Op): ...@@ -7050,7 +7142,8 @@ class AdvancedSubtensor1(Op):
%(fail)s; %(fail)s;
} }
if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) { if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {
PyErr_SetString(PyExc_IndexError, "Index contains values " PyErr_SetString(PyExc_IndexError,
"Index contains values "
"that are bigger than the maximum array " "that are bigger than the maximum array "
"size on this system."); "size on this system.");
%(fail)s; %(fail)s;
...@@ -7081,7 +7174,8 @@ class AdvancedSubtensor1(Op): ...@@ -7081,7 +7174,8 @@ class AdvancedSubtensor1(Op):
} }
if (%(output_name)s != NULL) { if (%(output_name)s != NULL) {
for (; i < nd; i++) { for (; i < nd; i++) {
if (shape[i] != PyArray_DIMS(%(a_name)s)[i-PyArray_NDIM(indices)+1]) { if (shape[i] != PyArray_DIMS(%(a_name)s)[
i-PyArray_NDIM(indices)+1]) {
Py_CLEAR(%(output_name)s); Py_CLEAR(%(output_name)s);
break; break;
} }
...@@ -7089,8 +7183,8 @@ class AdvancedSubtensor1(Op): ...@@ -7089,8 +7183,8 @@ class AdvancedSubtensor1(Op):
} }
} }
} }
%(output_name)s = (PyArrayObject*)PyArray_TakeFrom(%(a_name)s, indices, 0, %(output_name)s = (PyArrayObject*)PyArray_TakeFrom(
%(output_name)s, NPY_RAISE); %(a_name)s, indices, 0, %(output_name)s, NPY_RAISE);
Py_DECREF(indices); Py_DECREF(indices);
if (%(output_name)s == NULL) %(fail)s; if (%(output_name)s == NULL) %(fail)s;
""" % locals() """ % locals()
...@@ -7100,6 +7194,7 @@ class AdvancedSubtensor1(Op): ...@@ -7100,6 +7194,7 @@ class AdvancedSubtensor1(Op):
advanced_subtensor1 = AdvancedSubtensor1() advanced_subtensor1 = AdvancedSubtensor1()
class AdvancedIncSubtensor1(Op): class AdvancedIncSubtensor1(Op):
"""Increments a subtensor using advanced slicing (list of index)""" """Increments a subtensor using advanced slicing (list of index)"""
def __init__(self, inplace=False, set_instead_of_inc=False): def __init__(self, inplace=False, set_instead_of_inc=False):
...@@ -7163,10 +7258,10 @@ class AdvancedIncSubtensor1(Op): ...@@ -7163,10 +7258,10 @@ class AdvancedIncSubtensor1(Op):
x[idx] = y x[idx] = y
else: else:
increment = inplace_increment increment = inplace_increment
if increment is None: if increment is None:
increment = self.inplace_increment1d_slow increment = self.inplace_increment1d_slow
increment(x,idx, y) increment(x, idx, y)
out[0] = x out[0] = x
...@@ -7209,7 +7304,8 @@ class AdvancedIncSubtensor1(Op): ...@@ -7209,7 +7304,8 @@ class AdvancedIncSubtensor1(Op):
return [gx, gy] + [DisconnectedType()()] * len(idx_list) return [gx, gy] + [DisconnectedType()()] * len(idx_list)
advanced_inc_subtensor1 = AdvancedIncSubtensor1() advanced_inc_subtensor1 = AdvancedIncSubtensor1()
def as_index_variable(idx): def as_index_variable(idx):
if idx is None: if idx is None:
return NoneConst return NoneConst
...@@ -7269,6 +7365,7 @@ class SliceType(gof.Type): ...@@ -7269,6 +7365,7 @@ class SliceType(gof.Type):
slicetype = SliceType() slicetype = SliceType()
class NoneTypeT(gof.Type): class NoneTypeT(gof.Type):
def filter(self, x, strict=False, allow_downcast=None): def filter(self, x, strict=False, allow_downcast=None):
......
...@@ -6632,6 +6632,18 @@ class T_get_scalar_constant_value(unittest.TestCase): ...@@ -6632,6 +6632,18 @@ class T_get_scalar_constant_value(unittest.TestCase):
for j in range(c.value.shape[1]): for j in range(c.value.shape[1]):
assert get_scalar_constant_value(c[i, j]) == c.value[i, j] assert get_scalar_constant_value(c[i, j]) == c.value[i, j]
def test_numpy_array(self):
# Regression test for crash when called on a numpy array.
assert get_scalar_constant_value(numpy.array(3)) == 3
self.assertRaises(
tensor.NotScalarConstantError,
get_scalar_constant_value,
numpy.array([0, 1]))
self.assertRaises(
tensor.EmptyConstantError,
get_scalar_constant_value,
numpy.array([]))
class T_as_tensor_variable(unittest.TestCase): class T_as_tensor_variable(unittest.TestCase):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论