提交 a4f0dced authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 tensor/elemwise.py

上级 2c125069
......@@ -11,7 +11,7 @@ from six import iteritems
from six.moves import xrange
from theano.gof import Apply, Op, OpenMPOp
from theano import scalar
from theano.scalar import Scalar, get_scalar_type
from theano.scalar import get_scalar_type
from theano.printing import pprint
from theano.tensor.utils import hash_from_dict
from theano.gradient import DisconnectedType
......@@ -50,7 +50,7 @@ def TensorConstant(*inputs, **kwargs):
##################
### DimShuffle ###
# DimShuffle #
##################
class DimShuffle(Op):
......@@ -219,12 +219,11 @@ class DimShuffle(Op):
and self.input_broadcastable == other.input_broadcastable
def _rehash(self):
self._hashval = (
hash(type(self).__name__)
^ hash(type(self).__module__)
^ hash(self.inplace)
^ hash(self.new_order)
^ hash(self.input_broadcastable))
self._hashval = (hash(type(self).__name__) ^
hash(type(self).__module__) ^
hash(self.inplace) ^
hash(self.new_order) ^
hash(self.input_broadcastable))
def __hash__(self):
return self._hashval
......@@ -286,7 +285,8 @@ class DimShuffle(Op):
nd_out = len(self.new_order)
check_input_nd = [('if (PyArray_NDIM(%(input)s) != ' + str(nd_in) + ')'
'{PyErr_SetString(PyExc_NotImplementedError, "input nd"); %(fail)s;}')]
'{PyErr_SetString(PyExc_NotImplementedError, '
'"input nd"); %(fail)s;}')]
clear_output = ['if (%(res)s) {Py_XDECREF(%(res)s);}']
......@@ -296,8 +296,10 @@ class DimShuffle(Op):
get_base = [
'{ PyArrayObject * %(basename)s = %(input)s', 'Py_INCREF((PyObject*)%(basename)s)']
else:
get_base = [('{ PyArrayObject * %(basename)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s, NULL,'
'0, 0, NPY_ARRAY_ALIGNED|NPY_ARRAY_ENSURECOPY, NULL)')]
get_base = [('{ PyArrayObject * %(basename)s = '
'(PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s,'
' NULL, 0, 0, NPY_ARRAY_ALIGNED|NPY_ARRAY_ENSURECOPY,'
' NULL)')]
shape_statements = ['npy_intp dimensions[%i]' % nd_out]
for i, o in enumerate(self.new_order):
......@@ -312,9 +314,12 @@ class DimShuffle(Op):
# set the strides of the non-broadcasted dimensions
for i, o in enumerate(self.new_order):
if o != 'x':
strides_statements += [('strides[' + str(i)
+ '] = PyArray_DIMS(%(basename)s)[' + str(o)
+ '] == 1? 0 : PyArray_STRIDES(%(basename)s)[' + str(o) + ']')]
strides_statements += [('strides[' + str(i) +
'] = PyArray_DIMS(%(basename)s)[' +
str(o) +
'] == 1? 0 : '
'PyArray_STRIDES(%(basename)s)[' +
str(o) + ']')]
else:
strides_statements += [('strides[' + str(i) + '] = 0')]
......@@ -360,12 +365,12 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s);
"""
'}']
full_code = statements(check_input_nd
+ clear_output
+ get_base
+ shape_statements
+ strides_statements
+ close_bracket)
full_code = statements(check_input_nd +
clear_output +
get_base +
shape_statements +
strides_statements +
close_bracket)
if 0:
print('C_CODE')
......@@ -432,7 +437,7 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle),
################
### Elemwise ###
# Elemwise #
################
class Elemwise(OpenMPOp):
......@@ -518,7 +523,8 @@ class Elemwise(OpenMPOp):
self.nfunc = getattr(numpy, self.nfunc_spec[0])
elif self.scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(self.scalar_op.impl,
self.scalar_op.nin, self.scalar_op.nout)
self.scalar_op.nin,
self.scalar_op.nout)
self._rehash()
def make_node(self, *inputs):
......@@ -557,7 +563,8 @@ class Elemwise(OpenMPOp):
# it is multiplied by nout because Elemwise supports multiple outputs
# (nout of them)
out_broadcastables = [[all(bcast)
for bcast in izip(*[input.type.broadcastable
for bcast in
izip(*[input.type.broadcastable
for input in inputs])]] * shadow.nout
# inplace_pattern maps output idx -> input idx
......@@ -579,8 +586,8 @@ class Elemwise(OpenMPOp):
([i.type.dtype for i in inputs], out_dtypes, inplace_pattern)))
outputs = [TensorType(dtype=dtype, broadcastable=broadcastable)()
for dtype, broadcastable in izip(out_dtypes, out_broadcastables)
]
for dtype, broadcastable in izip(out_dtypes,
out_broadcastables)]
return Apply(self, inputs, outputs)
def __eq__(self, other):
......@@ -589,8 +596,8 @@ class Elemwise(OpenMPOp):
other_items = list(other.inplace_pattern.items())
items.sort()
other_items.sort()
rval = ((self.scalar_op == other.scalar_op)
and (items == other_items))
rval = ((self.scalar_op == other.scalar_op) and
(items == other_items))
return rval
return False
......@@ -714,7 +721,7 @@ class Elemwise(OpenMPOp):
# close for
sr = Sum(axis=to_sum)(rval[i])
sr = sr.dimshuffle(shuffle)
#sr = DimShuffle(sr.type.broadcastable, shuffle)(sr)
# sr = DimShuffle(sr.type.broadcastable, shuffle)(sr)
rval[i] = sr
# close if
# close for
......@@ -787,7 +794,6 @@ class Elemwise(OpenMPOp):
# should be disabled.
super(Elemwise, self).perform(node, inputs, output_storage)
maxsize = max(len(input.shape) for input in inputs)
for dims in izip(*[list(zip(input.shape, sinput.type.broadcastable))
for input, sinput in zip(inputs, node.inputs)]):
if max(d for d, b in dims) != 1 and (1, False) in dims:
......@@ -1192,15 +1198,16 @@ class Elemwise(OpenMPOp):
return self.scalar_op.c_support_code()
def c_support_code_apply(self, node, nodename):
support_code = self.scalar_op.c_support_code_apply(node,
nodename + '_scalar_')
support_code = self.scalar_op.c_support_code_apply(node, nodename +
'_scalar_')
return support_code
def c_code_cache_version_apply(self, node):
version = [12] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op,
scalar_node = Apply(
self.scalar_op,
[get_scalar_type(dtype=input.type.dtype).make_variable()
for input in node.inputs],
[get_scalar_type(dtype=output.type.dtype).make_variable()
......@@ -1233,7 +1240,7 @@ class Elemwise(OpenMPOp):
################
### CAReduce ###
# CAReduce #
################
class CAReduce(Op):
......@@ -1325,8 +1332,8 @@ class CAReduce(Op):
if self.axis is not None:
for axis in self.axis:
if (axis >= input.type.ndim
or (axis < 0 and abs(axis) > input.type.ndim)):
if (axis >= input.type.ndim or
(axis < 0 and abs(axis) > input.type.ndim)):
raise ValueError((
'Not enough dimensions on %s to reduce on axis %s'
% (input, axis)))
......@@ -1366,9 +1373,9 @@ class CAReduce(Op):
self.set_ufunc(self.scalar_op)
def __eq__(self, other):
return (type(self) == type(other)
and self.scalar_op == other.scalar_op
and self.axis == other.axis)
return (type(self) == type(other) and
self.scalar_op == other.scalar_op and
self.axis == other.axis)
def __hash__(self):
if self.axis is None:
......@@ -1420,8 +1427,8 @@ class CAReduce(Op):
# was built with "frompyfunc". We need to find out if we
# are in one of these cases (only "object" is supported in
# the output).
if ((self.ufunc.ntypes == 1)
and (self.ufunc.types[0][-1] == 'O')):
if ((self.ufunc.ntypes == 1) and
(self.ufunc.types[0][-1] == 'O')):
variable = self.ufunc.reduce(variable, dimension,
dtype='object')
else:
......@@ -1570,8 +1577,7 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
raise TypeError(
"The CAReduce.scalar_op must have an identity field.")
task0_decl = (
"%(dtype)s& %(name)s_i = *%(name)s_iter;\n"
task0_decl = ("%(dtype)s& %(name)s_i = *%(name)s_iter;\n"
"%(name)s_i = %(identity)s;"
% dict(dtype=adtype, name=aname, identity=identity))
......@@ -1579,8 +1585,7 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
% dict(dtype=idtype, name=inames[0]))
task1_code = self.scalar_op.c_code(
Apply(
self.scalar_op,
Apply(self.scalar_op,
[get_scalar_type(dtype=input.type.dtype).make_variable()
for input in (node.inputs * 2)],
[get_scalar_type(dtype=output.type.dtype).make_variable()
......@@ -1600,11 +1605,10 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
if len(axis) == 1:
all_code = [("", "")] * nnested + [(task0_decl, code1), ""]
else:
all_code = (
[("", "")] * nnested
+ [(task0_decl, "")]
+ [("", "")] * (len(axis) - 2)
+ [("", code1), ""])
all_code = ([("", "")] * nnested +
[(task0_decl, "")] +
[("", "")] * (len(axis) - 2) +
[("", code1), ""])
else:
all_code = [task0_decl + code1]
loop = cgen.make_loop_careduce(
......@@ -1632,7 +1636,8 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
version = [5] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op,
scalar_node = Apply(
self.scalar_op,
[get_scalar_type(dtype=input.type.dtype).make_variable()
for input in node.inputs],
[get_scalar_type(dtype=output.type.dtype).make_variable()
......@@ -1760,9 +1765,9 @@ class CAReduceDtype(CAReduce):
self.acc_dtype = acc_dtype
def __eq__(self, other):
return (CAReduce.__eq__(self, other)
and self.dtype == other.dtype
and self.acc_dtype == other.acc_dtype)
return (CAReduce.__eq__(self, other) and
self.dtype == other.dtype and
self.acc_dtype == other.acc_dtype)
def __hash__(self):
return CAReduce.__hash__(self) ^ hash((self.dtype, self.acc_dtype))
......@@ -1968,8 +1973,8 @@ class Prod(CAReduceDtype):
self.no_zeros_in_input = False
def __eq__(self, other):
return (CAReduceDtype.__eq__(self, other)
and self.no_zeros_in_input == other.no_zeros_in_input)
return (CAReduceDtype.__eq__(self, other) and
self.no_zeros_in_input == other.no_zeros_in_input)
def __hash__(self):
return (CAReduceDtype.__hash__(self) ^
......@@ -2124,25 +2129,26 @@ class MulWithoutZeros(scalar.BinaryScalarOp):
def c_code(self, node, name, inp, out, sub):
x, y = inp
z, = out
return (("%(z)s = ((%(x)s == 0) ? (%(y)s) : "
+ "((%(y)s == 0) ? (%(x)s) : ((%(y)s)*(%(x)s))) );")
return (("%(z)s = ((%(x)s == 0) ? (%(y)s) : " +
"((%(y)s == 0) ? (%(x)s) : ((%(y)s)*(%(x)s))) );")
% locals())
def c_code_cache_version(self):
return (1,)
mul_without_zeros = MulWithoutZeros(scalar.upcast_out,
name='mul_without_zeros')
mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name='mul_without_zeros')
class ProdWithoutZeros(CAReduceDtype):
def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__(self, mul_without_zeros, axis=axis,
dtype=dtype, acc_dtype=acc_dtype)
def grad(self, inp, grads):
a, = inp
a_grad = theano.gradient.grad_not_implemented(self, 0, a,
a_grad = theano.gradient.grad_not_implemented(
self, 0, a,
"2nd derivatives of `product(a)` is not currently supported."
"If `a` is guarenteed to contains no zeros, use `product(a, no_zeros_in_input=True)`."
)
"If `a` is guarenteed to contains no zeros, use "
"`product(a, no_zeros_in_input=True)`.")
return [a_grad]
......@@ -57,7 +57,6 @@ whitelist_flake8 = [
"typed_list/tests/test_type.py",
"typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py",
"tensor/elemwise.py",
"tensor/xlogx.py",
"tensor/blas_headers.py",
"tensor/utils.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论