提交 a4f0dced authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 tensor/elemwise.py

上级 2c125069
...@@ -11,7 +11,7 @@ from six import iteritems ...@@ -11,7 +11,7 @@ from six import iteritems
from six.moves import xrange from six.moves import xrange
from theano.gof import Apply, Op, OpenMPOp from theano.gof import Apply, Op, OpenMPOp
from theano import scalar from theano import scalar
from theano.scalar import Scalar, get_scalar_type from theano.scalar import get_scalar_type
from theano.printing import pprint from theano.printing import pprint
from theano.tensor.utils import hash_from_dict from theano.tensor.utils import hash_from_dict
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
...@@ -50,7 +50,7 @@ def TensorConstant(*inputs, **kwargs): ...@@ -50,7 +50,7 @@ def TensorConstant(*inputs, **kwargs):
################## ##################
### DimShuffle ### # DimShuffle #
################## ##################
class DimShuffle(Op): class DimShuffle(Op):
...@@ -139,8 +139,8 @@ class DimShuffle(Op): ...@@ -139,8 +139,8 @@ class DimShuffle(Op):
raise TypeError("DimShuffle indices must be python ints.") raise TypeError("DimShuffle indices must be python ints.")
if j >= len(input_broadcastable): if j >= len(input_broadcastable):
raise ValueError(("new_order[%d] is %d, but the input " raise ValueError(("new_order[%d] is %d, but the input "
"only has %d axes.") % "only has %d axes.") %
(i, j, len(input_broadcastable))) (i, j, len(input_broadcastable)))
if j in new_order[(i + 1):]: if j in new_order[(i + 1):]:
raise ValueError("The same input dimension may not appear " raise ValueError("The same input dimension may not appear "
"twice in the list of output dimensions", "twice in the list of output dimensions",
...@@ -207,7 +207,7 @@ class DimShuffle(Op): ...@@ -207,7 +207,7 @@ class DimShuffle(Op):
ob.append(ib[value]) ob.append(ib[value])
output = TensorType(dtype=input.type.dtype, output = TensorType(dtype=input.type.dtype,
broadcastable=ob).make_variable() broadcastable=ob).make_variable()
return Apply(self, [input], [output]) return Apply(self, [input], [output])
...@@ -219,12 +219,11 @@ class DimShuffle(Op): ...@@ -219,12 +219,11 @@ class DimShuffle(Op):
and self.input_broadcastable == other.input_broadcastable and self.input_broadcastable == other.input_broadcastable
def _rehash(self): def _rehash(self):
self._hashval = ( self._hashval = (hash(type(self).__name__) ^
hash(type(self).__name__) hash(type(self).__module__) ^
^ hash(type(self).__module__) hash(self.inplace) ^
^ hash(self.inplace) hash(self.new_order) ^
^ hash(self.new_order) hash(self.input_broadcastable))
^ hash(self.input_broadcastable))
def __hash__(self): def __hash__(self):
return self._hashval return self._hashval
...@@ -232,7 +231,7 @@ class DimShuffle(Op): ...@@ -232,7 +231,7 @@ class DimShuffle(Op):
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
return "InplaceDimShuffle{%s}" % ",".join(str(x) return "InplaceDimShuffle{%s}" % ",".join(str(x)
for x in self.new_order) for x in self.new_order)
else: else:
return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order) return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order)
...@@ -286,7 +285,8 @@ class DimShuffle(Op): ...@@ -286,7 +285,8 @@ class DimShuffle(Op):
nd_out = len(self.new_order) nd_out = len(self.new_order)
check_input_nd = [('if (PyArray_NDIM(%(input)s) != ' + str(nd_in) + ')' check_input_nd = [('if (PyArray_NDIM(%(input)s) != ' + str(nd_in) + ')'
'{PyErr_SetString(PyExc_NotImplementedError, "input nd"); %(fail)s;}')] '{PyErr_SetString(PyExc_NotImplementedError, '
'"input nd"); %(fail)s;}')]
clear_output = ['if (%(res)s) {Py_XDECREF(%(res)s);}'] clear_output = ['if (%(res)s) {Py_XDECREF(%(res)s);}']
...@@ -296,8 +296,10 @@ class DimShuffle(Op): ...@@ -296,8 +296,10 @@ class DimShuffle(Op):
get_base = [ get_base = [
'{ PyArrayObject * %(basename)s = %(input)s', 'Py_INCREF((PyObject*)%(basename)s)'] '{ PyArrayObject * %(basename)s = %(input)s', 'Py_INCREF((PyObject*)%(basename)s)']
else: else:
get_base = [('{ PyArrayObject * %(basename)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s, NULL,' get_base = [('{ PyArrayObject * %(basename)s = '
'0, 0, NPY_ARRAY_ALIGNED|NPY_ARRAY_ENSURECOPY, NULL)')] '(PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s,'
' NULL, 0, 0, NPY_ARRAY_ALIGNED|NPY_ARRAY_ENSURECOPY,'
' NULL)')]
shape_statements = ['npy_intp dimensions[%i]' % nd_out] shape_statements = ['npy_intp dimensions[%i]' % nd_out]
for i, o in enumerate(self.new_order): for i, o in enumerate(self.new_order):
...@@ -312,9 +314,12 @@ class DimShuffle(Op): ...@@ -312,9 +314,12 @@ class DimShuffle(Op):
# set the strides of the non-broadcasted dimensions # set the strides of the non-broadcasted dimensions
for i, o in enumerate(self.new_order): for i, o in enumerate(self.new_order):
if o != 'x': if o != 'x':
strides_statements += [('strides[' + str(i) strides_statements += [('strides[' + str(i) +
+ '] = PyArray_DIMS(%(basename)s)[' + str(o) '] = PyArray_DIMS(%(basename)s)[' +
+ '] == 1? 0 : PyArray_STRIDES(%(basename)s)[' + str(o) + ']')] str(o) +
'] == 1? 0 : '
'PyArray_STRIDES(%(basename)s)[' +
str(o) + ']')]
else: else:
strides_statements += [('strides[' + str(i) + '] = 0')] strides_statements += [('strides[' + str(i) + '] = 0')]
...@@ -360,12 +365,12 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s); ...@@ -360,12 +365,12 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s);
""" """
'}'] '}']
full_code = statements(check_input_nd full_code = statements(check_input_nd +
+ clear_output clear_output +
+ get_base get_base +
+ shape_statements shape_statements +
+ strides_statements strides_statements +
+ close_bracket) close_bracket)
if 0: if 0:
print('C_CODE') print('C_CODE')
...@@ -408,7 +413,7 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s); ...@@ -408,7 +413,7 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s);
class DimShufflePrinter: class DimShufflePrinter:
def __p(self, new_order, pstate, r): def __p(self, new_order, pstate, r):
if new_order != () and new_order[0] == 'x': if new_order != () and new_order[0] == 'x':
return "%s" % self.__p(new_order[1:], pstate, r) return "%s" % self.__p(new_order[1:], pstate, r)
# return "[%s]" % self.__p(new_order[1:], pstate, r) # return "[%s]" % self.__p(new_order[1:], pstate, r)
if list(new_order) == list(range(r.type.ndim)): if list(new_order) == list(range(r.type.ndim)):
...@@ -416,7 +421,7 @@ class DimShufflePrinter: ...@@ -416,7 +421,7 @@ class DimShufflePrinter:
if list(new_order) == list(reversed(range(r.type.ndim))): if list(new_order) == list(reversed(range(r.type.ndim))):
return "%s.T" % pstate.pprinter.process(r) return "%s.T" % pstate.pprinter.process(r)
return "DimShuffle{%s}(%s)" % (", ".join(map(str, new_order)), return "DimShuffle{%s}(%s)" % (", ".join(map(str, new_order)),
pstate.pprinter.process(r)) pstate.pprinter.process(r))
def process(self, r, pstate): def process(self, r, pstate):
if r.owner is None: if r.owner is None:
...@@ -428,11 +433,11 @@ class DimShufflePrinter: ...@@ -428,11 +433,11 @@ class DimShufflePrinter:
raise TypeError("Can only print DimShuffle.") raise TypeError("Can only print DimShuffle.")
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle),
DimShufflePrinter()) DimShufflePrinter())
################ ################
### Elemwise ### # Elemwise #
################ ################
class Elemwise(OpenMPOp): class Elemwise(OpenMPOp):
...@@ -496,7 +501,7 @@ class Elemwise(OpenMPOp): ...@@ -496,7 +501,7 @@ class Elemwise(OpenMPOp):
self.nfunc = getattr(numpy, nfunc_spec[0]) self.nfunc = getattr(numpy, nfunc_spec[0])
elif scalar_op.nin > 0: elif scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(scalar_op.impl, scalar_op.nin, self.ufunc = numpy.frompyfunc(scalar_op.impl, scalar_op.nin,
scalar_op.nout) scalar_op.nout)
# precompute the hash of this node # precompute the hash of this node
self._rehash() self._rehash()
...@@ -518,7 +523,8 @@ class Elemwise(OpenMPOp): ...@@ -518,7 +523,8 @@ class Elemwise(OpenMPOp):
self.nfunc = getattr(numpy, self.nfunc_spec[0]) self.nfunc = getattr(numpy, self.nfunc_spec[0])
elif self.scalar_op.nin > 0: elif self.scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.ufunc = numpy.frompyfunc(self.scalar_op.impl,
self.scalar_op.nin, self.scalar_op.nout) self.scalar_op.nin,
self.scalar_op.nout)
self._rehash() self._rehash()
def make_node(self, *inputs): def make_node(self, *inputs):
...@@ -557,15 +563,16 @@ class Elemwise(OpenMPOp): ...@@ -557,15 +563,16 @@ class Elemwise(OpenMPOp):
# it is multiplied by nout because Elemwise supports multiple outputs # it is multiplied by nout because Elemwise supports multiple outputs
# (nout of them) # (nout of them)
out_broadcastables = [[all(bcast) out_broadcastables = [[all(bcast)
for bcast in izip(*[input.type.broadcastable for bcast in
for input in inputs])]] * shadow.nout izip(*[input.type.broadcastable
for input in inputs])]] * shadow.nout
# inplace_pattern maps output idx -> input idx # inplace_pattern maps output idx -> input idx
inplace_pattern = self.inplace_pattern inplace_pattern = self.inplace_pattern
if inplace_pattern: if inplace_pattern:
for overwriter, overwritten in iteritems(inplace_pattern): for overwriter, overwritten in iteritems(inplace_pattern):
for ob, ib in izip(out_broadcastables[overwriter], for ob, ib in izip(out_broadcastables[overwriter],
inputs[overwritten].type.broadcastable): inputs[overwritten].type.broadcastable):
if ib and not ob: if ib and not ob:
raise ValueError( raise ValueError(
"Operation cannot be done inplace on an input " "Operation cannot be done inplace on an input "
...@@ -579,8 +586,8 @@ class Elemwise(OpenMPOp): ...@@ -579,8 +586,8 @@ class Elemwise(OpenMPOp):
([i.type.dtype for i in inputs], out_dtypes, inplace_pattern))) ([i.type.dtype for i in inputs], out_dtypes, inplace_pattern)))
outputs = [TensorType(dtype=dtype, broadcastable=broadcastable)() outputs = [TensorType(dtype=dtype, broadcastable=broadcastable)()
for dtype, broadcastable in izip(out_dtypes, out_broadcastables) for dtype, broadcastable in izip(out_dtypes,
] out_broadcastables)]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def __eq__(self, other): def __eq__(self, other):
...@@ -589,8 +596,8 @@ class Elemwise(OpenMPOp): ...@@ -589,8 +596,8 @@ class Elemwise(OpenMPOp):
other_items = list(other.inplace_pattern.items()) other_items = list(other.inplace_pattern.items())
items.sort() items.sort()
other_items.sort() other_items.sort()
rval = ((self.scalar_op == other.scalar_op) rval = ((self.scalar_op == other.scalar_op) and
and (items == other_items)) (items == other_items))
return rval return rval
return False return False
...@@ -628,7 +635,7 @@ class Elemwise(OpenMPOp): ...@@ -628,7 +635,7 @@ class Elemwise(OpenMPOp):
rop_out = None rop_out = None
for jdx, (inp, eval_point) in enumerate(izip(inputs, for jdx, (inp, eval_point) in enumerate(izip(inputs,
eval_points)): eval_points)):
# if None, then we can just ignore this branch .. # if None, then we can just ignore this branch ..
# what we do is to assume that for any non-differentiable # what we do is to assume that for any non-differentiable
# branch, the gradient is actually 0, which I think is not # branch, the gradient is actually 0, which I think is not
...@@ -668,7 +675,7 @@ class Elemwise(OpenMPOp): ...@@ -668,7 +675,7 @@ class Elemwise(OpenMPOp):
# to the gradient.grad method when the outputs have # to the gradient.grad method when the outputs have
# some integer and some floating point outputs # some integer and some floating point outputs
if False in [str(out.type.dtype).find('int') == -1 if False in [str(out.type.dtype).find('int') == -1
for out in outs]: for out in outs]:
# For integer output, return value may # For integer output, return value may
# only be zero or undefined # only be zero or undefined
# We don't bother with trying to check # We don't bother with trying to check
...@@ -699,7 +706,7 @@ class Elemwise(OpenMPOp): ...@@ -699,7 +706,7 @@ class Elemwise(OpenMPOp):
# we can sum over them # we can sum over them
# todo: only count dimensions that were effectively broadcasted # todo: only count dimensions that were effectively broadcasted
to_sum = [j for j, bcast in enumerate(ipt.type.broadcastable) to_sum = [j for j, bcast in enumerate(ipt.type.broadcastable)
if bcast] if bcast]
if to_sum: if to_sum:
shuffle = [] shuffle = []
...@@ -714,7 +721,7 @@ class Elemwise(OpenMPOp): ...@@ -714,7 +721,7 @@ class Elemwise(OpenMPOp):
# close for # close for
sr = Sum(axis=to_sum)(rval[i]) sr = Sum(axis=to_sum)(rval[i])
sr = sr.dimshuffle(shuffle) sr = sr.dimshuffle(shuffle)
#sr = DimShuffle(sr.type.broadcastable, shuffle)(sr) # sr = DimShuffle(sr.type.broadcastable, shuffle)(sr)
rval[i] = sr rval[i] = sr
# close if # close if
# close for # close for
...@@ -747,7 +754,7 @@ class Elemwise(OpenMPOp): ...@@ -747,7 +754,7 @@ class Elemwise(OpenMPOp):
if not isinstance(scalar_igrads, (list, tuple)): if not isinstance(scalar_igrads, (list, tuple)):
raise TypeError('%s.grad returned %s instead of list or tuple' % raise TypeError('%s.grad returned %s instead of list or tuple' %
(str(self.scalar_op), str(type(scalar_igrads)))) (str(self.scalar_op), str(type(scalar_igrads))))
nd = len(inputs[0].type.broadcastable) # this is the same for everyone nd = len(inputs[0].type.broadcastable) # this is the same for everyone
...@@ -787,9 +794,8 @@ class Elemwise(OpenMPOp): ...@@ -787,9 +794,8 @@ class Elemwise(OpenMPOp):
# should be disabled. # should be disabled.
super(Elemwise, self).perform(node, inputs, output_storage) super(Elemwise, self).perform(node, inputs, output_storage)
maxsize = max(len(input.shape) for input in inputs)
for dims in izip(*[list(zip(input.shape, sinput.type.broadcastable)) for dims in izip(*[list(zip(input.shape, sinput.type.broadcastable))
for input, sinput in zip(inputs, node.inputs)]): for input, sinput in zip(inputs, node.inputs)]):
if max(d for d, b in dims) != 1 and (1, False) in dims: if max(d for d, b in dims) != 1 and (1, False) in dims:
# yes there may be more compact ways to write this code, # yes there may be more compact ways to write this code,
# but please maintain python 2.4 compatibility # but please maintain python 2.4 compatibility
...@@ -1115,7 +1121,7 @@ class Elemwise(OpenMPOp): ...@@ -1115,7 +1121,7 @@ class Elemwise(OpenMPOp):
# use it! The scalar_op need to check the broadcast flag himself. # use it! The scalar_op need to check the broadcast flag himself.
if (all([o.ndim >= 1 for o in node.outputs]) and if (all([o.ndim >= 1 for o in node.outputs]) and
# Don't use the contig code for broadcasted scalar. # Don't use the contig code for broadcasted scalar.
not all(node.outputs[0].broadcastable)): not all(node.outputs[0].broadcastable)):
contig = None contig = None
try: try:
contig = self.scalar_op.c_code_contiguous( contig = self.scalar_op.c_code_contiguous(
...@@ -1192,19 +1198,20 @@ class Elemwise(OpenMPOp): ...@@ -1192,19 +1198,20 @@ class Elemwise(OpenMPOp):
return self.scalar_op.c_support_code() return self.scalar_op.c_support_code()
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
support_code = self.scalar_op.c_support_code_apply(node, support_code = self.scalar_op.c_support_code_apply(node, nodename +
nodename + '_scalar_') '_scalar_')
return support_code return support_code
def c_code_cache_version_apply(self, node): def c_code_cache_version_apply(self, node):
version = [12] # the version corresponding to the c code in this Op version = [12] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend... # now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op, scalar_node = Apply(
[get_scalar_type(dtype=input.type.dtype).make_variable() self.scalar_op,
for input in node.inputs], [get_scalar_type(dtype=input.type.dtype).make_variable()
[get_scalar_type(dtype=output.type.dtype).make_variable() for input in node.inputs],
for output in node.outputs]) [get_scalar_type(dtype=output.type.dtype).make_variable()
for output in node.outputs])
version.append(self.scalar_op.c_code_cache_version_apply(scalar_node)) version.append(self.scalar_op.c_code_cache_version_apply(scalar_node))
for i in node.inputs + node.outputs: for i in node.inputs + node.outputs:
version.append(get_scalar_type(dtype=i.type.dtype).c_code_cache_version()) version.append(get_scalar_type(dtype=i.type.dtype).c_code_cache_version())
...@@ -1233,7 +1240,7 @@ class Elemwise(OpenMPOp): ...@@ -1233,7 +1240,7 @@ class Elemwise(OpenMPOp):
################ ################
### CAReduce ### # CAReduce #
################ ################
class CAReduce(Op): class CAReduce(Op):
...@@ -1325,8 +1332,8 @@ class CAReduce(Op): ...@@ -1325,8 +1332,8 @@ class CAReduce(Op):
if self.axis is not None: if self.axis is not None:
for axis in self.axis: for axis in self.axis:
if (axis >= input.type.ndim if (axis >= input.type.ndim or
or (axis < 0 and abs(axis) > input.type.ndim)): (axis < 0 and abs(axis) > input.type.ndim)):
raise ValueError(( raise ValueError((
'Not enough dimensions on %s to reduce on axis %s' 'Not enough dimensions on %s to reduce on axis %s'
% (input, axis))) % (input, axis)))
...@@ -1366,9 +1373,9 @@ class CAReduce(Op): ...@@ -1366,9 +1373,9 @@ class CAReduce(Op):
self.set_ufunc(self.scalar_op) self.set_ufunc(self.scalar_op)
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) return (type(self) == type(other) and
and self.scalar_op == other.scalar_op self.scalar_op == other.scalar_op and
and self.axis == other.axis) self.axis == other.axis)
def __hash__(self): def __hash__(self):
if self.axis is None: if self.axis is None:
...@@ -1420,13 +1427,13 @@ class CAReduce(Op): ...@@ -1420,13 +1427,13 @@ class CAReduce(Op):
# was built with "frompyfunc". We need to find out if we # was built with "frompyfunc". We need to find out if we
# are in one of these cases (only "object" is supported in # are in one of these cases (only "object" is supported in
# the output). # the output).
if ((self.ufunc.ntypes == 1) if ((self.ufunc.ntypes == 1) and
and (self.ufunc.types[0][-1] == 'O')): (self.ufunc.types[0][-1] == 'O')):
variable = self.ufunc.reduce(variable, dimension, variable = self.ufunc.reduce(variable, dimension,
dtype='object') dtype='object')
else: else:
variable = self.ufunc.reduce(variable, dimension, variable = self.ufunc.reduce(variable, dimension,
dtype=acc_dtype) dtype=acc_dtype)
variable = numpy.asarray(variable) variable = numpy.asarray(variable)
if numpy.may_share_memory(variable, input): if numpy.may_share_memory(variable, input):
...@@ -1434,7 +1441,7 @@ class CAReduce(Op): ...@@ -1434,7 +1441,7 @@ class CAReduce(Op):
# We don't want this. # We don't want this.
variable = variable.copy() variable = variable.copy()
output[0] = theano._asarray(variable, output[0] = theano._asarray(variable,
dtype=node.outputs[0].type.dtype) dtype=node.outputs[0].type.dtype)
else: else:
# Force a copy # Force a copy
output[0] = numpy.array(variable, copy=True, output[0] = numpy.array(variable, copy=True,
...@@ -1568,27 +1575,25 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1568,27 +1575,25 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
""" % locals() """ % locals()
else: else:
raise TypeError( raise TypeError(
"The CAReduce.scalar_op must have an identity field.") "The CAReduce.scalar_op must have an identity field.")
task0_decl = ( task0_decl = ("%(dtype)s& %(name)s_i = *%(name)s_iter;\n"
"%(dtype)s& %(name)s_i = *%(name)s_iter;\n" "%(name)s_i = %(identity)s;"
"%(name)s_i = %(identity)s;" % dict(dtype=adtype, name=aname, identity=identity))
% dict(dtype=adtype, name=aname, identity=identity))
task1_decl = ("%(dtype)s& %(name)s_i = *%(name)s_iter;\n" task1_decl = ("%(dtype)s& %(name)s_i = *%(name)s_iter;\n"
% dict(dtype=idtype, name=inames[0])) % dict(dtype=idtype, name=inames[0]))
task1_code = self.scalar_op.c_code( task1_code = self.scalar_op.c_code(
Apply( Apply(self.scalar_op,
self.scalar_op, [get_scalar_type(dtype=input.type.dtype).make_variable()
[get_scalar_type(dtype=input.type.dtype).make_variable() for input in (node.inputs * 2)],
for input in (node.inputs * 2)], [get_scalar_type(dtype=output.type.dtype).make_variable()
[get_scalar_type(dtype=output.type.dtype).make_variable() for input in node.outputs]),
for input in node.outputs]), None,
None, ["%s_i" % aname, "%s_i" % inames[0]],
["%s_i" % aname, "%s_i" % inames[0]], ["%s_i" % aname],
["%s_i" % aname], sub)
sub)
code1 = """ code1 = """
{ {
%(task1_decl)s %(task1_decl)s
...@@ -1600,11 +1605,10 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1600,11 +1605,10 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
if len(axis) == 1: if len(axis) == 1:
all_code = [("", "")] * nnested + [(task0_decl, code1), ""] all_code = [("", "")] * nnested + [(task0_decl, code1), ""]
else: else:
all_code = ( all_code = ([("", "")] * nnested +
[("", "")] * nnested [(task0_decl, "")] +
+ [(task0_decl, "")] [("", "")] * (len(axis) - 2) +
+ [("", "")] * (len(axis) - 2) [("", code1), ""])
+ [("", code1), ""])
else: else:
all_code = [task0_decl + code1] all_code = [task0_decl + code1]
loop = cgen.make_loop_careduce( loop = cgen.make_loop_careduce(
...@@ -1632,11 +1636,12 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1632,11 +1636,12 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
version = [5] # the version corresponding to the c code in this Op version = [5] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend... # now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op, scalar_node = Apply(
[get_scalar_type(dtype=input.type.dtype).make_variable() self.scalar_op,
for input in node.inputs], [get_scalar_type(dtype=input.type.dtype).make_variable()
[get_scalar_type(dtype=output.type.dtype).make_variable() for input in node.inputs],
for output in node.outputs]) [get_scalar_type(dtype=output.type.dtype).make_variable()
for output in node.outputs])
version.append(self.scalar_op.c_code_cache_version_apply(scalar_node)) version.append(self.scalar_op.c_code_cache_version_apply(scalar_node))
for i in node.inputs + node.outputs: for i in node.inputs + node.outputs:
version.append(get_scalar_type(dtype=i.type.dtype).c_code_cache_version()) version.append(get_scalar_type(dtype=i.type.dtype).c_code_cache_version())
...@@ -1760,9 +1765,9 @@ class CAReduceDtype(CAReduce): ...@@ -1760,9 +1765,9 @@ class CAReduceDtype(CAReduce):
self.acc_dtype = acc_dtype self.acc_dtype = acc_dtype
def __eq__(self, other): def __eq__(self, other):
return (CAReduce.__eq__(self, other) return (CAReduce.__eq__(self, other) and
and self.dtype == other.dtype self.dtype == other.dtype and
and self.acc_dtype == other.acc_dtype) self.acc_dtype == other.acc_dtype)
def __hash__(self): def __hash__(self):
return CAReduce.__hash__(self) ^ hash((self.dtype, self.acc_dtype)) return CAReduce.__hash__(self) ^ hash((self.dtype, self.acc_dtype))
...@@ -1968,8 +1973,8 @@ class Prod(CAReduceDtype): ...@@ -1968,8 +1973,8 @@ class Prod(CAReduceDtype):
self.no_zeros_in_input = False self.no_zeros_in_input = False
def __eq__(self, other): def __eq__(self, other):
return (CAReduceDtype.__eq__(self, other) return (CAReduceDtype.__eq__(self, other) and
and self.no_zeros_in_input == other.no_zeros_in_input) self.no_zeros_in_input == other.no_zeros_in_input)
def __hash__(self): def __hash__(self):
return (CAReduceDtype.__hash__(self) ^ return (CAReduceDtype.__hash__(self) ^
...@@ -2124,25 +2129,26 @@ class MulWithoutZeros(scalar.BinaryScalarOp): ...@@ -2124,25 +2129,26 @@ class MulWithoutZeros(scalar.BinaryScalarOp):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, y = inp x, y = inp
z, = out z, = out
return (("%(z)s = ((%(x)s == 0) ? (%(y)s) : " return (("%(z)s = ((%(x)s == 0) ? (%(y)s) : " +
+ "((%(y)s == 0) ? (%(x)s) : ((%(y)s)*(%(x)s))) );") "((%(y)s == 0) ? (%(x)s) : ((%(y)s)*(%(x)s))) );")
% locals()) % locals())
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
mul_without_zeros = MulWithoutZeros(scalar.upcast_out, mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name='mul_without_zeros')
name='mul_without_zeros')
class ProdWithoutZeros(CAReduceDtype): class ProdWithoutZeros(CAReduceDtype):
def __init__(self, axis=None, dtype=None, acc_dtype=None): def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__(self, mul_without_zeros, axis=axis, CAReduceDtype.__init__(self, mul_without_zeros, axis=axis,
dtype=dtype, acc_dtype=acc_dtype) dtype=dtype, acc_dtype=acc_dtype)
def grad(self, inp, grads): def grad(self, inp, grads):
a, = inp a, = inp
a_grad = theano.gradient.grad_not_implemented(self, 0, a, a_grad = theano.gradient.grad_not_implemented(
"2nd derivatives of `product(a)` is not currently supported." self, 0, a,
"If `a` is guarenteed to contains no zeros, use `product(a, no_zeros_in_input=True)`." "2nd derivatives of `product(a)` is not currently supported."
) "If `a` is guarenteed to contains no zeros, use "
"`product(a, no_zeros_in_input=True)`.")
return [a_grad] return [a_grad]
...@@ -57,7 +57,6 @@ whitelist_flake8 = [ ...@@ -57,7 +57,6 @@ whitelist_flake8 = [
"typed_list/tests/test_type.py", "typed_list/tests/test_type.py",
"typed_list/tests/test_opt.py", "typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py", "typed_list/tests/test_basic.py",
"tensor/elemwise.py",
"tensor/xlogx.py", "tensor/xlogx.py",
"tensor/blas_headers.py", "tensor/blas_headers.py",
"tensor/utils.py", "tensor/utils.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论