提交 83bf7251 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

PEP8

上级 5ff97962
...@@ -4250,8 +4250,8 @@ def get_canonical_form_slice(theslice, length): ...@@ -4250,8 +4250,8 @@ def get_canonical_form_slice(theslice, length):
that respects the conventions imposed by python and numpy. that respects the conventions imposed by python and numpy.
In a canonical form a slice is represented by a canonical form slice, In a canonical form a slice is represented by a canonical form slice,
in which 0 <= start <= stop <= length and step > 0, and a flag which says if in which 0 <= start <= stop <= length and step > 0, and a flag which says
the resulting set of numbers needs to be reversed or not. if the resulting set of numbers needs to be reversed or not.
''' '''
if isinstance(theslice, slice): if isinstance(theslice, slice):
...@@ -4292,8 +4292,8 @@ def get_canonical_form_slice(theslice, length): ...@@ -4292,8 +4292,8 @@ def get_canonical_form_slice(theslice, length):
# Full slice. # Full slice.
return slice(0, length, 1), 1 return slice(0, length, 1), 1
if is_stop_constant and stop >= 0: if is_stop_constant and stop >= 0:
return (slice(0, switch(lt(stop, length), stop, length), 1), return (slice(0, switch(lt(stop, length), stop, length),
1) 1), 1)
stop_plus_len = stop + length stop_plus_len = stop + length
stop = switch( stop = switch(
lt(stop, 0), lt(stop, 0),
...@@ -4343,6 +4343,7 @@ def get_canonical_form_slice(theslice, length): ...@@ -4343,6 +4343,7 @@ def get_canonical_form_slice(theslice, length):
sgn_step = -1 sgn_step = -1
else: else:
is_step_neg = lt(step, 0) is_step_neg = lt(step, 0)
def switch_neg_step(a, b): def switch_neg_step(a, b):
return switch(is_step_neg, a, b) return switch(is_step_neg, a, b)
abs_step = abs(step) abs_step = abs(step)
...@@ -7120,7 +7121,8 @@ class AdvancedSubtensor1(Op): ...@@ -7120,7 +7121,8 @@ class AdvancedSubtensor1(Op):
// if all values fit. // if all values fit.
if (!PyArray_CanCastSafely(i_type, NPY_INTP)) { if (!PyArray_CanCastSafely(i_type, NPY_INTP)) {
npy_int64 min_val, max_val; npy_int64 min_val, max_val;
PyObject* py_min_val = PyArray_Min(%(i_name)s, NPY_MAXDIMS, NULL); PyObject* py_min_val = PyArray_Min(%(i_name)s, NPY_MAXDIMS,
NULL);
if (py_min_val == NULL) { if (py_min_val == NULL) {
%(fail)s; %(fail)s;
} }
...@@ -7129,7 +7131,8 @@ class AdvancedSubtensor1(Op): ...@@ -7129,7 +7131,8 @@ class AdvancedSubtensor1(Op):
if (min_val == -1 && PyErr_Occurred()) { if (min_val == -1 && PyErr_Occurred()) {
%(fail)s; %(fail)s;
} }
PyObject* py_max_val = PyArray_Max(%(i_name)s, NPY_MAXDIMS, NULL); PyObject* py_max_val = PyArray_Max(%(i_name)s, NPY_MAXDIMS,
NULL);
if (py_max_val == NULL) { if (py_max_val == NULL) {
%(fail)s; %(fail)s;
} }
...@@ -7139,7 +7142,8 @@ class AdvancedSubtensor1(Op): ...@@ -7139,7 +7142,8 @@ class AdvancedSubtensor1(Op):
%(fail)s; %(fail)s;
} }
if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) { if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {
PyErr_SetString(PyExc_IndexError, "Index contains values " PyErr_SetString(PyExc_IndexError,
"Index contains values "
"that are bigger than the maximum array " "that are bigger than the maximum array "
"size on this system."); "size on this system.");
%(fail)s; %(fail)s;
...@@ -7170,7 +7174,8 @@ class AdvancedSubtensor1(Op): ...@@ -7170,7 +7174,8 @@ class AdvancedSubtensor1(Op):
} }
if (%(output_name)s != NULL) { if (%(output_name)s != NULL) {
for (; i < nd; i++) { for (; i < nd; i++) {
if (shape[i] != PyArray_DIMS(%(a_name)s)[i-PyArray_NDIM(indices)+1]) { if (shape[i] != PyArray_DIMS(%(a_name)s)[
i-PyArray_NDIM(indices)+1]) {
Py_CLEAR(%(output_name)s); Py_CLEAR(%(output_name)s);
break; break;
} }
...@@ -7178,8 +7183,8 @@ class AdvancedSubtensor1(Op): ...@@ -7178,8 +7183,8 @@ class AdvancedSubtensor1(Op):
} }
} }
} }
%(output_name)s = (PyArrayObject*)PyArray_TakeFrom(%(a_name)s, indices, 0, %(output_name)s = (PyArrayObject*)PyArray_TakeFrom(
%(output_name)s, NPY_RAISE); %(a_name)s, indices, 0, %(output_name)s, NPY_RAISE);
Py_DECREF(indices); Py_DECREF(indices);
if (%(output_name)s == NULL) %(fail)s; if (%(output_name)s == NULL) %(fail)s;
""" % locals() """ % locals()
...@@ -7189,6 +7194,7 @@ class AdvancedSubtensor1(Op): ...@@ -7189,6 +7194,7 @@ class AdvancedSubtensor1(Op):
advanced_subtensor1 = AdvancedSubtensor1() advanced_subtensor1 = AdvancedSubtensor1()
class AdvancedIncSubtensor1(Op): class AdvancedIncSubtensor1(Op):
"""Increments a subtensor using advanced slicing (list of index)""" """Increments a subtensor using advanced slicing (list of index)"""
def __init__(self, inplace=False, set_instead_of_inc=False): def __init__(self, inplace=False, set_instead_of_inc=False):
...@@ -7252,10 +7258,10 @@ class AdvancedIncSubtensor1(Op): ...@@ -7252,10 +7258,10 @@ class AdvancedIncSubtensor1(Op):
x[idx] = y x[idx] = y
else: else:
increment = inplace_increment increment = inplace_increment
if increment is None: if increment is None:
increment = self.inplace_increment1d_slow increment = self.inplace_increment1d_slow
increment(x,idx, y) increment(x, idx, y)
out[0] = x out[0] = x
...@@ -7298,7 +7304,8 @@ class AdvancedIncSubtensor1(Op): ...@@ -7298,7 +7304,8 @@ class AdvancedIncSubtensor1(Op):
return [gx, gy] + [DisconnectedType()()] * len(idx_list) return [gx, gy] + [DisconnectedType()()] * len(idx_list)
advanced_inc_subtensor1 = AdvancedIncSubtensor1() advanced_inc_subtensor1 = AdvancedIncSubtensor1()
def as_index_variable(idx): def as_index_variable(idx):
if idx is None: if idx is None:
return NoneConst return NoneConst
...@@ -7358,6 +7365,7 @@ class SliceType(gof.Type): ...@@ -7358,6 +7365,7 @@ class SliceType(gof.Type):
slicetype = SliceType() slicetype = SliceType()
class NoneTypeT(gof.Type): class NoneTypeT(gof.Type):
def filter(self, x, strict=False, allow_downcast=None): def filter(self, x, strict=False, allow_downcast=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论