提交 2f004fac authored 作者: Harm de Vries's avatar Harm de Vries

pep8

上级 f57a7ccb
...@@ -109,7 +109,8 @@ if 0: ...@@ -109,7 +109,8 @@ if 0:
# - JB 20100226 # - JB 20100226
def as_cuda_or_tensor_variable(x, name=None, ndim=None): def as_cuda_or_tensor_variable(x, name=None, ndim=None):
""" """
Do the same as_tensor_variable, but do not transfer the value on the gpu. Do the same as_tensor_variable,
but do not transfer the value on the gpu.
""" """
if hasattr(x, '_as_CudaNdarrayVariable'): if hasattr(x, '_as_CudaNdarrayVariable'):
# TODO: pass name and ndim arguments # TODO: pass name and ndim arguments
...@@ -516,7 +517,8 @@ def _allclose(a, b, rtol=None, atol=None): ...@@ -516,7 +517,8 @@ def _allclose(a, b, rtol=None, atol=None):
if atol is not None: if atol is not None:
atol_ = atol atol_ = atol
# Work around bug in Numpy, see http://projects.scipy.org/numpy/ticket/1684 # Work around bug in Numpy, see
# http://projects.scipy.org/numpy/ticket/1684
if str(b.dtype) in int_dtypes and (numpy.absolute(b) < 0).any(): if str(b.dtype) in int_dtypes and (numpy.absolute(b) < 0).any():
b = theano._asarray(b, dtype='float64') b = theano._asarray(b, dtype='float64')
...@@ -1289,12 +1291,6 @@ class MaxAndArgmax(Op): ...@@ -1289,12 +1291,6 @@ class MaxAndArgmax(Op):
if isinstance(axis, (tuple, list, numpy.ndarray)): if isinstance(axis, (tuple, list, numpy.ndarray)):
# List of axes: make them non-negative, and sort them # List of axes: make them non-negative, and sort them
axis = [int(a) for a in axis] axis = [int(a) for a in axis]
#if axis == list(range(-x.type.ndim, 0, 1)):
#axis = list(range(x.type.ndim))
#assert axis == list(range(x.type.ndim)), (
#"MaxAndArgmax does not support multiple"
#" axes. the max fct supports it. Got %s" % axis)
if axis == list(range(x.type.ndim)): if axis == list(range(x.type.ndim)):
axis = None axis = None
elif isinstance(axis, (int, numpy.integer)): elif isinstance(axis, (int, numpy.integer)):
...@@ -1311,7 +1307,8 @@ class MaxAndArgmax(Op): ...@@ -1311,7 +1307,8 @@ class MaxAndArgmax(Op):
assert (axis.dtype.startswith("int") or assert (axis.dtype.startswith("int") or
axis.dtype.startswith("uint")) axis.dtype.startswith("uint"))
if isinstance(axis.data, (int, numpy.integer)) or \ if isinstance(axis.data, (int, numpy.integer)) or \
(isinstance(axis.data, numpy.ndarray) and axis.data.ndim == 0): (isinstance(axis.data, numpy.ndarray) and
axis.data.ndim == 0):
axis = [int(axis.data)] axis = [int(axis.data)]
elif isinstance(axis.data, (list, numpy.ndarray)): elif isinstance(axis.data, (list, numpy.ndarray)):
axis = [int(i) for i in axis.data] axis = [int(i) for i in axis.data]
...@@ -1365,9 +1362,11 @@ class MaxAndArgmax(Op): ...@@ -1365,9 +1362,11 @@ class MaxAndArgmax(Op):
keep_axes = numpy.array([i for i in range(x.ndim) if i not in axes]) keep_axes = numpy.array([i for i in range(x.ndim) if i not in axes])
# Not-reduced axes in front # Not-reduced axes in front
transposed_x = numpy.transpose(x, numpy.concatenate((keep_axes, axes))) transposed_x = numpy.transpose(x, numpy.concatenate((keep_axes, axes)))
reshaped_x = transposed_x.reshape(transposed_x.shape[:len(keep_axes)] + (-1,)) reshaped_x = transposed_x.reshape(
transposed_x.shape[:len(keep_axes)] + (-1,))
max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1), dtype='int64') max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1),
dtype='int64')
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, axis = inp x, axis = inp
...@@ -1378,11 +1377,13 @@ class MaxAndArgmax(Op): ...@@ -1378,11 +1377,13 @@ class MaxAndArgmax(Op):
else: else:
assert node.inputs[1].ndim == 1 assert node.inputs[1].ndim == 1
# Fall back to perform() if there are multiple axes # Fall back to perform() if there are multiple axes
if len(node.inputs[1].data) > 1: raise NotImplementedError() if len(node.inputs[1].data) > 1:
raise NotImplementedError()
axis_code = """ axis_code = """
axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0]; axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){ if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument"); PyErr_SetString(PyExc_ValueError,
"MaxAndArgmax, bad axis argument");
%(fail)s %(fail)s
} }
""" % locals() """ % locals()
......
...@@ -2922,8 +2922,8 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -2922,8 +2922,8 @@ class T_max_and_argmax(unittest.TestCase):
z[argmax] += 1 z[argmax] += 1
else: else:
for id, v in enumerate(argmax): for id, v in enumerate(argmax):
z[v * numpy.prod(data.shape[data.ndim - 1:axis:-1]) z[v * numpy.prod(data.shape[data.ndim - 1:axis:-1]) +
+ id] += 1 id] += 1
z = z.reshape(data.shape) z = z.reshape(data.shape)
assert numpy.all(max_grad_data == z) assert numpy.all(max_grad_data == z)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论