提交 bd168a56 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3127 from harlouci/flake8_scalar

Flake8 scalar
...@@ -242,22 +242,21 @@ class Scalar(Type): ...@@ -242,22 +242,21 @@ class Scalar(Type):
print(dtype, np.zeros(1, dtype=dtype).dtype.num) print(dtype, np.zeros(1, dtype=dtype).dtype.num)
""" """
return { # dtype: (py_type, c_type, cls_name) return { # dtype: (py_type, c_type, cls_name)
'float16': (numpy.float16, 'npy_float16', 'Float16'), 'float16': (numpy.float16, 'npy_float16', 'Float16'),
'float32': (numpy.float32, 'npy_float32', 'Float32'), 'float32': (numpy.float32, 'npy_float32', 'Float32'),
'float64': (numpy.float64, 'npy_float64', 'Float64'), 'float64': (numpy.float64, 'npy_float64', 'Float64'),
'complex128': (numpy.complex128, 'theano_complex128', 'complex128': (numpy.complex128, 'theano_complex128',
'Complex128'), 'Complex128'),
'complex64': (numpy.complex64, 'theano_complex64', 'complex64': (numpy.complex64, 'theano_complex64', 'Complex64'),
'Complex64'), 'uint8': (numpy.uint8, 'npy_uint8', 'UInt8'),
'uint8': (numpy.uint8, 'npy_uint8', 'UInt8'), 'int8': (numpy.int8, 'npy_int8', 'Int8'),
'int8': (numpy.int8, 'npy_int8', 'Int8'), 'uint16': (numpy.uint16, 'npy_uint16', 'UInt16'),
'uint16': (numpy.uint16, 'npy_uint16', 'UInt16'), 'int16': (numpy.int16, 'npy_int16', 'Int16'),
'int16': (numpy.int16, 'npy_int16', 'Int16'), 'uint32': (numpy.uint32, 'npy_uint32', 'UInt32'),
'uint32': (numpy.uint32, 'npy_uint32', 'UInt32'), 'int32': (numpy.int32, 'npy_int32', 'Int32'),
'int32': (numpy.int32, 'npy_int32', 'Int32'), 'uint64': (numpy.uint64, 'npy_uint64', 'UInt64'),
'uint64': (numpy.uint64, 'npy_uint64', 'UInt64'), 'int64': (numpy.int64, 'npy_int64', 'Int64')
'int64': (numpy.int64, 'npy_int64', 'Int64') }[self.dtype]
}[self.dtype]
except KeyError: except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % ( raise TypeError("Unsupported dtype for %s: %s" % (
self.__class__.__name__, self.dtype)) self.__class__.__name__, self.dtype))
...@@ -348,7 +347,7 @@ class Scalar(Type): ...@@ -348,7 +347,7 @@ class Scalar(Type):
# 'npy_intX', some C code may not compile, e.g. when assigning # 'npy_intX', some C code may not compile, e.g. when assigning
# the value 0 (cast to 'int' in C) to a theano_complex64. # the value 0 (cast to 'int' in C) to a theano_complex64.
if (numpy.dtype('intc').num not in if (numpy.dtype('intc').num not in
[numpy.dtype(d[4:]).num for d in real_types]): [numpy.dtype(d[4:]).num for d in real_types]):
# In that case we add the 'int' type to the real types. # In that case we add the 'int' type to the real types.
real_types.append('int') real_types.append('int')
...@@ -421,12 +420,12 @@ class Scalar(Type): ...@@ -421,12 +420,12 @@ class Scalar(Type):
{ this->real=y.real; this->imag=y.imag; return *this; } { this->real=y.real; this->imag=y.imag; return *this; }
''' % dict(mytype=mytype, othertype=othertype) ''' % dict(mytype=mytype, othertype=othertype)
operator_eq = ''.join(operator_eq_real(ctype, rtype) operator_eq = (''.join(operator_eq_real(ctype, rtype)
for ctype in cplx_types for ctype in cplx_types
for rtype in real_types) \ for rtype in real_types) +
+ ''.join(operator_eq_cplx(ctype1, ctype2) ''.join(operator_eq_cplx(ctype1, ctype2)
for ctype1 in cplx_types for ctype1 in cplx_types
for ctype2 in cplx_types) for ctype2 in cplx_types))
# We are not using C++ generic templating here, because this would # We are not using C++ generic templating here, because this would
# generate two different functions for adding a complex64 and a # generate two different functions for adding a complex64 and a
...@@ -473,12 +472,12 @@ class Scalar(Type): ...@@ -473,12 +472,12 @@ class Scalar(Type):
for ctype in cplx_types for ctype in cplx_types
for rtype in real_types) for rtype in real_types)
return template % dict(nbits=64, half_nbits=32) \ return (template % dict(nbits=64, half_nbits=32) +
+ template % dict(nbits=128, half_nbits=64) \ template % dict(nbits=128, half_nbits=64) +
+ operator_eq \ operator_eq +
+ operator_plus \ operator_plus +
+ operator_minus \ operator_minus +
+ operator_mul operator_mul)
else: else:
return "" return ""
...@@ -544,9 +543,9 @@ class _scalar_py_operators: ...@@ -544,9 +543,9 @@ class _scalar_py_operators:
return neg(self) return neg(self)
# CASTS # CASTS
#def __int__(self): return AsInt(self).out # def __int__(self): return AsInt(self).out
#def __float__(self): return AsDouble(self).out # def __float__(self): return AsDouble(self).out
#def __complex__(self): return AsComplex(self).out # def __complex__(self): return AsComplex(self).out
# BITWISE # BITWISE
def __invert__(self): def __invert__(self):
...@@ -583,7 +582,7 @@ class _scalar_py_operators: ...@@ -583,7 +582,7 @@ class _scalar_py_operators:
def __ge__(self, other): def __ge__(self, other):
return ge(self, other) return ge(self, other)
#ARITHMETIC - NORMAL # ARITHMETIC - NORMAL
def __add__(self, other): def __add__(self, other):
return add(self, other) return add(self, other)
...@@ -609,7 +608,7 @@ class _scalar_py_operators: ...@@ -609,7 +608,7 @@ class _scalar_py_operators:
def __pow__(self, other): def __pow__(self, other):
return pow(self, other) return pow(self, other)
#ARITHMETIC - RIGHT-OPERAND # ARITHMETIC - RIGHT-OPERAND
def __radd__(self, other): def __radd__(self, other):
return add(other, self) return add(other, self)
...@@ -694,7 +693,7 @@ class upgrade_to_float(object): ...@@ -694,7 +693,7 @@ class upgrade_to_float(object):
uint32: float64, uint32: float64,
uint64: float64} uint64: float64}
return get_scalar_type(Scalar.upcast(*[conv.get(type, type) return get_scalar_type(Scalar.upcast(*[conv.get(type, type)
for type in types])), for type in types])),
class same_out(object): class same_out(object):
...@@ -891,9 +890,9 @@ class ScalarOp(Op): ...@@ -891,9 +890,9 @@ class ScalarOp(Op):
self.__class__.__name__) self.__class__.__name__)
def __eq__(self, other): def __eq__(self, other):
test = type(self) == type(other) \ test = (type(self) == type(other) and
and getattr(self, 'output_types_preference', None) \ getattr(self, 'output_types_preference', None) ==
== getattr(other, 'output_types_preference', None) getattr(other, 'output_types_preference', None))
return test return test
def __hash__(self): def __hash__(self):
...@@ -942,9 +941,9 @@ class UnaryScalarOp(ScalarOp): ...@@ -942,9 +941,9 @@ class UnaryScalarOp(ScalarOp):
(x,) = inputs (x,) = inputs
(z,) = outputs (z,) = outputs
if (not theano.config.lib.amdlibm or if (not theano.config.lib.amdlibm or
# We compare the dtype AND the broadcast flag # We compare the dtype AND the broadcast flag
# as this function do not broadcast # as this function do not broadcast
node.inputs[0].type != node.outputs[0].type): node.inputs[0].type != node.outputs[0].type):
raise theano.gof.utils.MethodNotDefined() raise theano.gof.utils.MethodNotDefined()
dtype = node.inputs[0].type.dtype_specs()[1] dtype = node.inputs[0].type.dtype_specs()[1]
...@@ -1176,7 +1175,7 @@ class InRange(LogicalComparison): ...@@ -1176,7 +1175,7 @@ class InRange(LogicalComparison):
cmp1 = '>=' cmp1 = '>='
# backport # backport
#cmp1 = '>' if self.openlow else '>=' # cmp1 = '>' if self.openlow else '>='
if self.openhi: if self.openhi:
cmp2 = '<' cmp2 = '<'
...@@ -1184,14 +1183,14 @@ class InRange(LogicalComparison): ...@@ -1184,14 +1183,14 @@ class InRange(LogicalComparison):
cmp2 = '<=' cmp2 = '<='
# backport # backport
#cmp2 = '<' if self.openhi else '<=' # cmp2 = '<' if self.openhi else '<='
return ("%(z)s = %(x)s %(cmp1)s %(low)s &&" return ("%(z)s = %(x)s %(cmp1)s %(low)s &&"
" %(x)s %(cmp2)s %(hi)s;" % locals()) " %(x)s %(cmp2)s %(hi)s;" % locals())
def get_grad(self, elem): def get_grad(self, elem):
if elem.type in complex_types: if elem.type in complex_types:
msg = "No gradient implemented for complex numbers in\ msg = ("No gradient implemented for complex numbers in "
class scalar.basic.InRange" "class scalar.basic.InRange")
raise NotImplementedError(msg) raise NotImplementedError(msg)
elif elem.type in discrete_types: elif elem.type in discrete_types:
return elem.zeros_like().astype(theano.config.floatX) return elem.zeros_like().astype(theano.config.floatX)
...@@ -1473,7 +1472,7 @@ class Mul(ScalarOp): ...@@ -1473,7 +1472,7 @@ class Mul(ScalarOp):
# output is complex. The rest of this function make this supposition. # output is complex. The rest of this function make this supposition.
output_type = self.output_types([i.type for i in inputs])[0] output_type = self.output_types([i.type for i in inputs])[0]
if output_type in complex_types: if output_type in complex_types:
if not gz.type in complex_types: if gz.type not in complex_types:
raise TypeError( raise TypeError(
'Mul with output_type ' + str(output_type) + 'Mul with output_type ' + str(output_type) +
' expected gz type to be complex, got gz with type ' + ' expected gz type to be complex, got gz with type ' +
...@@ -1600,7 +1599,7 @@ class TrueDiv(BinaryScalarOp): ...@@ -1600,7 +1599,7 @@ class TrueDiv(BinaryScalarOp):
node.inputs[1].type in complex_types]) == 1: node.inputs[1].type in complex_types]) == 1:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
if (node.inputs[0].type in discrete_types and if (node.inputs[0].type in discrete_types and
node.inputs[1].type in discrete_types): node.inputs[1].type in discrete_types):
return "%(z)s = ((double)%(x)s) / %(y)s;" % locals() return "%(z)s = ((double)%(x)s) / %(y)s;" % locals()
return "%(z)s = %(x)s / %(y)s;" % locals() return "%(z)s = %(x)s / %(y)s;" % locals()
...@@ -1710,7 +1709,7 @@ floor_div = int_div ...@@ -1710,7 +1709,7 @@ floor_div = int_div
def mod_check(x, y): def mod_check(x, y):
if (as_scalar(x).type in complex_types or if (as_scalar(x).type in complex_types or
as_scalar(y).type in complex_types): as_scalar(y).type in complex_types):
# Currently forbidden. # Currently forbidden.
raise Mod.complex_error raise Mod.complex_error
else: else:
...@@ -1808,7 +1807,7 @@ class Pow(BinaryScalarOp): ...@@ -1808,7 +1807,7 @@ class Pow(BinaryScalarOp):
(x, y) = inputs (x, y) = inputs
(z,) = outputs (z,) = outputs
if (node.inputs[0].type in complex_types or if (node.inputs[0].type in complex_types or
node.inputs[1].type in complex_types): node.inputs[1].type in complex_types):
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = pow(%(x)s, %(y)s);" % locals() return "%(z)s = pow(%(x)s, %(y)s);" % locals()
...@@ -1838,10 +1837,10 @@ class Pow(BinaryScalarOp): ...@@ -1838,10 +1837,10 @@ class Pow(BinaryScalarOp):
# We compare the dtype AND the broadcast flag # We compare the dtype AND the broadcast flag
# as this function do not broadcast # as this function do not broadcast
if (node.inputs[0].type == node.outputs[0].type and if (node.inputs[0].type == node.outputs[0].type and
node.inputs[1].type == node.outputs[0].type and node.inputs[1].type == node.outputs[0].type and
# amdlibm 3.0 do not have a float64 version of this SIMD function # amdlibm 3.0 do not have a float64 version of this SIMD function
node.inputs[0].dtype == 'float32' and node.inputs[0].dtype == 'float32' and
node.inputs[1].dtype == 'float32'): node.inputs[1].dtype == 'float32'):
dtype = 'float' dtype = 'float'
fct = "amd_vrsa_powf" fct = "amd_vrsa_powf"
return """ return """
...@@ -2014,19 +2013,19 @@ convert_to_complex64 = Cast(complex64, name='convert_to_complex64') ...@@ -2014,19 +2013,19 @@ convert_to_complex64 = Cast(complex64, name='convert_to_complex64')
convert_to_complex128 = Cast(complex128, name='convert_to_complex128') convert_to_complex128 = Cast(complex128, name='convert_to_complex128')
_cast_mapping = { _cast_mapping = {
'int8': convert_to_int8, 'int8': convert_to_int8,
'int16': convert_to_int16, 'int16': convert_to_int16,
'int32': convert_to_int32, 'int32': convert_to_int32,
'int64': convert_to_int64, 'int64': convert_to_int64,
'uint8': convert_to_uint8, 'uint8': convert_to_uint8,
'uint16': convert_to_uint16, 'uint16': convert_to_uint16,
'uint32': convert_to_uint32, 'uint32': convert_to_uint32,
'uint64': convert_to_uint64, 'uint64': convert_to_uint64,
'float16': convert_to_float16, 'float16': convert_to_float16,
'float32': convert_to_float32, 'float32': convert_to_float32,
'float64': convert_to_float64, 'float64': convert_to_float64,
'complex64': convert_to_complex64, 'complex64': convert_to_complex64,
'complex128': convert_to_complex128} 'complex128': convert_to_complex128}
def cast(x, dtype): def cast(x, dtype):
...@@ -2201,7 +2200,7 @@ class RoundHalfToEven(UnaryScalarOp): ...@@ -2201,7 +2200,7 @@ class RoundHalfToEven(UnaryScalarOp):
(x,) = inputs (x,) = inputs
(z,) = outputs (z,) = outputs
typ = node.outputs[0].type.dtype typ = node.outputs[0].type.dtype
if not typ in ['float32', 'float64']: if typ not in ['float32', 'float64']:
Exception("The output should be float32 or float64") Exception("The output should be float32 or float64")
return dedent(""" return dedent("""
...@@ -2946,7 +2945,7 @@ class ArcTan2(BinaryScalarOp): ...@@ -2946,7 +2945,7 @@ class ArcTan2(BinaryScalarOp):
(y, x) = inputs (y, x) = inputs
(z,) = outputs (z,) = outputs
if (node.inputs[0].type in complex_types or if (node.inputs[0].type in complex_types or
node.inputs[1].type in complex_types): node.inputs[1].type in complex_types):
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = atan2(%(y)s, %(x)s);" % locals() return "%(z)s = atan2(%(y)s, %(x)s);" % locals()
arctan2 = ArcTan2(upgrade_to_float, name='arctan2') arctan2 = ArcTan2(upgrade_to_float, name='arctan2')
...@@ -3309,7 +3308,7 @@ class Composite(ScalarOp): ...@@ -3309,7 +3308,7 @@ class Composite(ScalarOp):
"All orphans in the fgraph to Composite must" "All orphans in the fgraph to Composite must"
" be Constant instances.") " be Constant instances.")
elif (any(i.dtype == 'float16' for i in var.owner.inputs) or elif (any(i.dtype == 'float16' for i in var.owner.inputs) or
any(o.dtype == 'float16' for o in var.owner.outputs)): any(o.dtype == 'float16' for o in var.owner.outputs)):
# flag for elemwise ops to check. # flag for elemwise ops to check.
self.inner_float16 = True self.inner_float16 = True
...@@ -3325,13 +3324,13 @@ class Composite(ScalarOp): ...@@ -3325,13 +3324,13 @@ class Composite(ScalarOp):
name = "V%%(id)s_tmp%i" % i name = "V%%(id)s_tmp%i" % i
subd[output] = name subd[output] = name
_c_code += "%s %s;\n" % ( _c_code += "%s %s;\n" % (
output.type.dtype_specs()[1], name) output.type.dtype_specs()[1], name)
s = node.op.c_code(node, s = node.op.c_code(
self.nodenames[j], node,
[subd[input] for input in node.inputs], self.nodenames[j],
[subd[output] for output in node.outputs], [subd[input] for input in node.inputs],
dict(fail="%(fail)s", [subd[output] for output in node.outputs],
id="%%(id)s_%i" % j)) dict(fail="%(fail)s", id="%%(id)s_%i" % j))
_c_code += s _c_code += s
_c_code += "\n" _c_code += "\n"
_c_code += "}\n" _c_code += "}\n"
...@@ -3454,7 +3453,7 @@ class Composite(ScalarOp): ...@@ -3454,7 +3453,7 @@ class Composite(ScalarOp):
def make_node(self, *inputs): def make_node(self, *inputs):
if (tuple([i.type for i in self.inputs]) == if (tuple([i.type for i in self.inputs]) ==
tuple([i.type for i in inputs])): tuple([i.type for i in inputs])):
return super(Composite, self).make_node(*inputs) return super(Composite, self).make_node(*inputs)
else: else:
# Make a new op with the right input type. # Make a new op with the right input type.
...@@ -3489,7 +3488,7 @@ class Composite(ScalarOp): ...@@ -3489,7 +3488,7 @@ class Composite(ScalarOp):
izip(("o%i" % i for i in xrange(len(onames))), izip(("o%i" % i for i in xrange(len(onames))),
onames)), **sub) onames)), **sub)
d['nodename'] = nodename d['nodename'] = nodename
if not 'id' in sub: if 'id' not in sub:
# The use of a dummy id is safe as the code is in a separate block. # The use of a dummy id is safe as the code is in a separate block.
# It won't generate conflicting variable name. # It won't generate conflicting variable name.
d['id'] = '_DUMMY_ID_' d['id'] = '_DUMMY_ID_'
...@@ -3521,8 +3520,8 @@ class Composite(ScalarOp): ...@@ -3521,8 +3520,8 @@ class Composite(ScalarOp):
for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames): for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
try: try:
subnode_support_code = subnode.op.c_support_code_apply( subnode_support_code = subnode.op.c_support_code_apply(
subnode, subnode,
subnodename % dict(nodename=name)) subnodename % dict(nodename=name))
if subnode_support_code: if subnode_support_code:
rval.append(subnode_support_code) rval.append(subnode_support_code)
except gof.utils.MethodNotDefined: except gof.utils.MethodNotDefined:
...@@ -3536,9 +3535,9 @@ class Composite(ScalarOp): ...@@ -3536,9 +3535,9 @@ class Composite(ScalarOp):
def __eq__(self, other): def __eq__(self, other):
if self is other: if self is other:
return True return True
if (type(self) != type(other) if (type(self) != type(other) or
or self.nin != other.nin self.nin != other.nin or
or self.nout != other.nout): self.nout != other.nout):
return False return False
# see __hash__ for comment on why there is no mention of fgraph # see __hash__ for comment on why there is no mention of fgraph
# or module cache key here. # or module cache key here.
...@@ -3546,9 +3545,9 @@ class Composite(ScalarOp): ...@@ -3546,9 +3545,9 @@ class Composite(ScalarOp):
def __hash__(self): def __hash__(self):
rval = hash((type(self), rval = hash((type(self),
self.nin, self.nin,
self.nout, self.nout,
self._c_code)) self._c_code))
# Note that in general, the configparser settings at the time # Note that in general, the configparser settings at the time
# of code generation (__init__) affect the semantics of this Op. # of code generation (__init__) affect the semantics of this Op.
# This function assumes that all relevant info about the configparser # This function assumes that all relevant info about the configparser
......
...@@ -296,51 +296,51 @@ class Psi(UnaryScalarOp): ...@@ -296,51 +296,51 @@ class Psi(UnaryScalarOp):
def c_support_code(self): def c_support_code(self):
return ( return (
""" """
// For GPU support // For GPU support
#ifdef __CUDACC__ #ifdef __CUDACC__
#define DEVICE __device__ #define DEVICE __device__
#else #else
#define DEVICE #define DEVICE
#endif #endif
#ifndef _PSIFUNCDEFINED #ifndef _PSIFUNCDEFINED
#define _PSIFUNCDEFINED #define _PSIFUNCDEFINED
DEVICE double _psi(double x){ DEVICE double _psi(double x){
/*taken from /*taken from
Bernardo, J. M. (1976). Algorithm AS 103: Bernardo, J. M. (1976). Algorithm AS 103:
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317. Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
http://www.uv.es/~bernardo/1976AppStatist.pdf */ http://www.uv.es/~bernardo/1976AppStatist.pdf */
double y, R, psi_ = 0; double y, R, psi_ = 0;
double S = 1.0e-5; double S = 1.0e-5;
double C = 8.5; double C = 8.5;
double S3 = 8.333333333e-2; double S3 = 8.333333333e-2;
double S4 = 8.333333333e-3; double S4 = 8.333333333e-3;
double S5 = 3.968253968e-3; double S5 = 3.968253968e-3;
double D1 = -0.5772156649; double D1 = -0.5772156649;
y = x; y = x;
if (y <= 0.0) if (y <= 0.0)
return psi_; return psi_;
if (y <= S ) if (y <= S )
return D1 - 1.0/y; return D1 - 1.0/y;
while (y < C){ while (y < C){
psi_ = psi_ - 1.0 / y; psi_ = psi_ - 1.0 / y;
y = y + 1;} y = y + 1;}
R = 1.0 / y; R = 1.0 / y;
psi_ = psi_ + log(y) - .5 * R ; psi_ = psi_ + log(y) - .5 * R ;
R= R*R; R= R*R;
psi_ = psi_ - R * (S3 - R * (S4 - R * S5)); psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
return psi_;} return psi_;}
#endif #endif
""" ) """)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, = inp x, = inp
......
import numpy as np import itertools as it
from theano.scalar.basic import Apply, ScalarOp, as_scalar, float64, float32, int64 from theano.scalar.basic import Apply, ScalarOp, as_scalar, float64, float32, int64
from theano.gof.utils import remove from theano.gof.utils import remove
imported_sympy = False imported_sympy = False
try: try:
import sympy
from sympy.utilities.codegen import get_default_datatype, codegen from sympy.utilities.codegen import get_default_datatype, codegen
imported_sympy = True imported_sympy = True
except ImportError: except ImportError:
pass pass
import itertools as it names = ("sympy_func_%d" % i for i in it.count(0))
names = ("sympy_func_%d"%i for i in it.count(0))
def include_line(line): def include_line(line):
...@@ -53,8 +51,8 @@ class SymPyCCode(ScalarOp): ...@@ -53,8 +51,8 @@ class SymPyCCode(ScalarOp):
def _sympy_c_code(self): def _sympy_c_code(self):
[(c_name, c_code), (h_name, c_header)] = codegen( [(c_name, c_code), (h_name, c_header)] = codegen(
(self.name, self.expr), 'C', 'project_name', (self.name, self.expr), 'C', 'project_name',
header=False, argument_sequence=self.inputs) header=False, argument_sequence=self.inputs)
return c_code return c_code
def c_support_code(self): def c_support_code(self):
...@@ -64,8 +62,8 @@ class SymPyCCode(ScalarOp): ...@@ -64,8 +62,8 @@ class SymPyCCode(ScalarOp):
def c_headers(self): def c_headers(self):
c_code = self._sympy_c_code() c_code = self._sympy_c_code()
return [line.replace("#include", "").strip() for line in return [line.replace("#include", "").strip() for line in
c_code.split('\n') if include_line(line) c_code.split('\n') if include_line(line) and
and not 'project_name' in line] 'project_name' not in line]
def c_code(self, node, name, input_names, output_names, sub): def c_code(self, node, name, input_names, output_names, sub):
y, = output_names y, = output_names
...@@ -92,7 +90,7 @@ class SymPyCCode(ScalarOp): ...@@ -92,7 +90,7 @@ class SymPyCCode(ScalarOp):
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
return [SymPyCCode(self.inputs, return [SymPyCCode(self.inputs,
self.expr.diff(inp), self.expr.diff(inp),
name=self.name+"_grad_%d"%i)(*inputs) name=self.name + "_grad_%d" % i)(*inputs)
for i, inp in enumerate(self.inputs)] for i, inp in enumerate(self.inputs)]
def _info(self): def _info(self):
......
...@@ -14,17 +14,18 @@ default when calling theano.shared(value) then users must really go out of their ...@@ -14,17 +14,18 @@ default when calling theano.shared(value) then users must really go out of their
way (as scan does) to create a shared variable of this kind. way (as scan does) to create a shared variable of this kind.
""" """
__authors__ = "James Bergstra"
__copyright__ = "(c) 2010, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en"
import numpy import numpy
from theano.compile import SharedVariable from theano.compile import SharedVariable
from .basic import Scalar, _scalar_py_operators from .basic import Scalar, _scalar_py_operators
__authors__ = "James Bergstra"
__copyright__ = "(c) 2010, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en"
class ScalarSharedVariable(_scalar_py_operators, SharedVariable): class ScalarSharedVariable(_scalar_py_operators, SharedVariable):
pass pass
...@@ -41,7 +42,7 @@ def shared(value, name=None, strict=False, allow_downcast=None): ...@@ -41,7 +42,7 @@ def shared(value, name=None, strict=False, allow_downcast=None):
:note: We implement this using 0-d tensors for now. :note: We implement this using 0-d tensors for now.
""" """
if not isinstance (value, (numpy.number, float, int, complex)): if not isinstance(value, (numpy.number, float, int, complex)):
raise TypeError() raise TypeError()
try: try:
dtype = value.dtype dtype = value.dtype
...@@ -52,7 +53,9 @@ def shared(value, name=None, strict=False, allow_downcast=None): ...@@ -52,7 +53,9 @@ def shared(value, name=None, strict=False, allow_downcast=None):
value = getattr(numpy, dtype)(value) value = getattr(numpy, dtype)(value)
scalar_type = Scalar(dtype=dtype) scalar_type = Scalar(dtype=dtype)
rval = ScalarSharedVariable( rval = ScalarSharedVariable(
type=scalar_type, type=scalar_type,
value=value, value=value,
name=name, strict=strict, allow_downcast=allow_downcast) name=name,
strict=strict,
allow_downcast=allow_downcast)
return rval return rval
...@@ -114,11 +114,7 @@ whitelist_flake8 = [ ...@@ -114,11 +114,7 @@ whitelist_flake8 = [
"tensor/nnet/tests/test_conv3d.py", "tensor/nnet/tests/test_conv3d.py",
"tensor/nnet/tests/speed_test_conv.py", "tensor/nnet/tests/speed_test_conv.py",
"tensor/nnet/tests/test_sigm.py", "tensor/nnet/tests/test_sigm.py",
"scalar/sharedvar.py",
"scalar/basic_scipy.py",
"scalar/basic_sympy.py",
"scalar/__init__.py", "scalar/__init__.py",
"scalar/basic.py",
"scalar/tests/test_basic.py", "scalar/tests/test_basic.py",
"sandbox/test_theano_object.py", "sandbox/test_theano_object.py",
"sandbox/test_scan.py", "sandbox/test_scan.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论