提交 2c7949b6 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1526 from nouiz/lamblin-fix_pickle_cache_leak2

fix pickle cache leak
...@@ -14,6 +14,30 @@ from theano.scan_module import scan ...@@ -14,6 +14,30 @@ from theano.scan_module import scan
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
# Used in TestComputeTestValue.test_no_perform
class IncOneC(Op):
"""An Op with only a C (c_code) implementation"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, input):
input = scalar.as_scalar(input)
output = input.type()
return Apply(self, [input], [output])
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
x, = inputs
z, = outputs
return "%(z)s = %(x)s + 1;" % locals()
class TestComputeTestValue(unittest.TestCase): class TestComputeTestValue(unittest.TestCase):
def test_variable_only(self): def test_variable_only(self):
...@@ -338,28 +362,6 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -338,28 +362,6 @@ class TestComputeTestValue(unittest.TestCase):
def test_no_perform(self): def test_no_perform(self):
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.") raise SkipTest("G++ not available, so we need to skip this test.")
class IncOneC(Op):
"""An Op with only a C (c_code) implementation"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, input):
input = scalar.as_scalar(input)
output = input.type()
return Apply(self, [input], [output])
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
x, = inputs
z, = outputs
return "%(z)s = %(x)s + 1;" % locals()
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
...@@ -368,6 +370,8 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -368,6 +370,8 @@ class TestComputeTestValue(unittest.TestCase):
i = scalar.int32('i') i = scalar.int32('i')
i.tag.test_value = 3 i.tag.test_value = 3
# Class IncOneC is defined outside of the TestComputeTestValue
# so it can be pickled and unpickled
o = IncOneC()(i) o = IncOneC()(i)
# Check that the perform function is not implemented # Check that the perform function is not implemented
......
...@@ -24,7 +24,7 @@ import theano ...@@ -24,7 +24,7 @@ import theano
from theano.compat import PY3 from theano.compat import PY3
from theano import gof from theano import gof
from theano.gof import (Op, utils, Variable, Constant, Type, Apply, from theano.gof import (Op, utils, Variable, Constant, Type, Apply,
FunctionGraph) FunctionGraph)
from theano.gof.python25 import partial, all, any from theano.gof.python25 import partial, all, any
from theano.configparser import config from theano.configparser import config
...@@ -137,7 +137,7 @@ class Scalar(Type): ...@@ -137,7 +137,7 @@ class Scalar(Type):
py_type = self.dtype_specs()[0] py_type = self.dtype_specs()[0]
if strict and not isinstance(data, py_type): if strict and not isinstance(data, py_type):
raise TypeError("%s expected a %s, got %s of type %s" % ( raise TypeError("%s expected a %s, got %s of type %s" % (
self, py_type, data, type(data)), data) self, py_type, data, type(data)), data)
try: try:
converted_data = py_type(data) converted_data = py_type(data)
if (allow_downcast or if (allow_downcast or
...@@ -148,10 +148,11 @@ class Scalar(Type): ...@@ -148,10 +148,11 @@ class Scalar(Type):
return py_type(data) return py_type(data)
else: else:
raise TypeError('Value cannot accurately be converted to dtype' raise TypeError('Value cannot accurately be converted to dtype'
' (%s) and allow_downcast is not True' % self.dtype) ' (%s) and allow_downcast is not True' %
self.dtype)
except Exception, e: except Exception, e:
raise TypeError("Could not convert %s (value=%s) to %s" % ( raise TypeError("Could not convert %s (value=%s) to %s" % (
type(data), data, self.dtype), e) type(data), data, self.dtype), e)
def values_eq_approx(self, a, b, tolerance=1e-4): def values_eq_approx(self, a, b, tolerance=1e-4):
return abs(a - b) <= ((abs(a) + abs(b)) * tolerance) return abs(a - b) <= ((abs(a) + abs(b)) * tolerance)
...@@ -222,7 +223,7 @@ class Scalar(Type): ...@@ -222,7 +223,7 @@ class Scalar(Type):
}[self.dtype] }[self.dtype]
except KeyError: except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % ( raise TypeError("Unsupported dtype for %s: %s" % (
self.__class__.__name__, self.dtype)) self.__class__.__name__, self.dtype))
def upcast(self, *others): def upcast(self, *others):
return upcast(*[x.dtype for x in [self] + list(others)]) return upcast(*[x.dtype for x in [self] + list(others)])
...@@ -373,11 +374,11 @@ class Scalar(Type): ...@@ -373,11 +374,11 @@ class Scalar(Type):
''' % dict(mytype=mytype, othertype=othertype) ''' % dict(mytype=mytype, othertype=othertype)
operator_eq = ''.join(operator_eq_real(ctype, rtype) operator_eq = ''.join(operator_eq_real(ctype, rtype)
for ctype in cplx_types for ctype in cplx_types
for rtype in real_types) \ for rtype in real_types) \
+ ''.join(operator_eq_cplx(ctype1, ctype2) + ''.join(operator_eq_cplx(ctype1, ctype2)
for ctype1 in cplx_types for ctype1 in cplx_types
for ctype2 in cplx_types) for ctype2 in cplx_types)
# We are not using C++ generic templating here, because this would # We are not using C++ generic templating here, because this would
# generate two different functions for adding a complex64 and a # generate two different functions for adding a complex64 and a
...@@ -395,8 +396,8 @@ class Scalar(Type): ...@@ -395,8 +396,8 @@ class Scalar(Type):
''' % dict(mytype=mytype, othertype=othertype) ''' % dict(mytype=mytype, othertype=othertype)
operator_plus = ''.join(operator_plus_real(ctype, rtype) operator_plus = ''.join(operator_plus_real(ctype, rtype)
for ctype in cplx_types for ctype in cplx_types
for rtype in real_types) for rtype in real_types)
def operator_minus_real(mytype, othertype): def operator_minus_real(mytype, othertype):
return ''' return '''
...@@ -408,8 +409,8 @@ class Scalar(Type): ...@@ -408,8 +409,8 @@ class Scalar(Type):
''' % dict(mytype=mytype, othertype=othertype) ''' % dict(mytype=mytype, othertype=othertype)
operator_minus = ''.join(operator_minus_real(ctype, rtype) operator_minus = ''.join(operator_minus_real(ctype, rtype)
for ctype in cplx_types for ctype in cplx_types
for rtype in real_types) for rtype in real_types)
def operator_mul_real(mytype, othertype): def operator_mul_real(mytype, othertype):
return ''' return '''
...@@ -421,15 +422,15 @@ class Scalar(Type): ...@@ -421,15 +422,15 @@ class Scalar(Type):
''' % dict(mytype=mytype, othertype=othertype) ''' % dict(mytype=mytype, othertype=othertype)
operator_mul = ''.join(operator_mul_real(ctype, rtype) operator_mul = ''.join(operator_mul_real(ctype, rtype)
for ctype in cplx_types for ctype in cplx_types
for rtype in real_types) for rtype in real_types)
return template % dict(nbits=64, half_nbits=32) \ return template % dict(nbits=64, half_nbits=32) \
+ template % dict(nbits=128, half_nbits=64) \ + template % dict(nbits=128, half_nbits=64) \
+ operator_eq \ + operator_eq \
+ operator_plus \ + operator_plus \
+ operator_minus \ + operator_minus \
+ operator_mul + operator_mul
else: else:
return "" return ""
...@@ -448,11 +449,11 @@ class Scalar(Type): ...@@ -448,11 +449,11 @@ class Scalar(Type):
# Register C code for ViewOp on Scalars. # Register C code for ViewOp on Scalars.
theano.compile.register_view_op_c_code( theano.compile.register_view_op_c_code(
Scalar, Scalar,
""" """
%(oname)s = %(iname)s; %(oname)s = %(iname)s;
""", """,
1) 1)
int8 = Scalar('int8') int8 = Scalar('int8')
...@@ -788,17 +789,18 @@ class ScalarOp(Op): ...@@ -788,17 +789,18 @@ class ScalarOp(Op):
if output_types_preference is not None: if output_types_preference is not None:
if not callable(output_types_preference): if not callable(output_types_preference):
raise TypeError( raise TypeError(
"Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" % (self.__class__, output_types_preference)) "Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" %
self.__class__, output_types_preference)
self.output_types_preference = output_types_preference self.output_types_preference = output_types_preference
def make_node(self, *inputs): def make_node(self, *inputs):
if self.nin >= 0: if self.nin >= 0:
if len(inputs) != self.nin: if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" \ raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" %
% (self, len(inputs), str(inputs), self.nin)) self, len(inputs), str(inputs), self.nin)
inputs = [as_scalar(input) for input in inputs] inputs = [as_scalar(input) for input in inputs]
outputs = [t() for t in self.output_types([input. outputs = [t() for t in self.output_types([input.type
type for input in inputs])] for input in inputs])]
if len(outputs) != self.nout: if len(outputs) != self.nout:
raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s." raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s."
% (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs))) % (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs)))
...@@ -906,6 +908,7 @@ class UnaryScalarOp(ScalarOp): ...@@ -906,6 +908,7 @@ class UnaryScalarOp(ScalarOp):
%(fct)s(n, x, z); %(fct)s(n, x, z);
""" % locals() """ % locals()
class BinaryScalarOp(ScalarOp): class BinaryScalarOp(ScalarOp):
# One may define in subclasses the following fields: # One may define in subclasses the following fields:
# - `identity`: for an associative operation, identity corresponds to # - `identity`: for an associative operation, identity corresponds to
...@@ -940,7 +943,7 @@ class FixedLogicalComparison(UnaryScalarOp): ...@@ -940,7 +943,7 @@ class FixedLogicalComparison(UnaryScalarOp):
return [int8] return [int8]
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
x ,= inputs x, = inputs
out = self(x) out = self(x)
assert str(out.type.dtype).find('int') != -1 assert str(out.type.dtype).find('int') != -1
return [x.zeros_like().astype(theano.config.floatX)] return [x.zeros_like().astype(theano.config.floatX)]
...@@ -1169,8 +1172,9 @@ class BinaryBitOp(BinaryScalarOp): ...@@ -1169,8 +1172,9 @@ class BinaryBitOp(BinaryScalarOp):
return upcast_out(*input_types[0]) return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
a,b = inputs a, b = inputs
return [a.zeros_like().astype(theano.config.floatX), b.zeros_like().astype(theano.config.floatX)] return [a.zeros_like().astype(theano.config.floatX),
b.zeros_like().astype(theano.config.floatX)]
class OR(BinaryBitOp): class OR(BinaryBitOp):
...@@ -1237,7 +1241,7 @@ class Maximum(BinaryScalarOp): ...@@ -1237,7 +1241,7 @@ class Maximum(BinaryScalarOp):
raise NotImplementedError() raise NotImplementedError()
# Test for both y>x and x>=y to detect NaN # Test for both y>x and x>=y to detect NaN
return ('%(z)s = ((%(y)s)>(%(x)s)? (%(y)s): ' return ('%(z)s = ((%(y)s)>(%(x)s)? (%(y)s): '
'((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals()) '((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals())
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types assert gz.type not in complex_types
...@@ -1268,7 +1272,7 @@ class Minimum(BinaryScalarOp): ...@@ -1268,7 +1272,7 @@ class Minimum(BinaryScalarOp):
if any([i.type in complex_types for i in node.inputs]): if any([i.type in complex_types for i in node.inputs]):
raise NotImplementedError() raise NotImplementedError()
return ('%(z)s = ((%(y)s)<(%(x)s)? (%(y)s): ' return ('%(z)s = ((%(y)s)<(%(x)s)? (%(y)s): '
'((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals()) '((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals())
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types assert gz.type not in complex_types
...@@ -1307,7 +1311,7 @@ class Add(ScalarOp): ...@@ -1307,7 +1311,7 @@ class Add(ScalarOp):
for ii, inp in enumerate(inputs): for ii, inp in enumerate(inputs):
if hasattr(inp, 'zeros_like'): if hasattr(inp, 'zeros_like'):
retval.append( retval.append(
inp.zeros_like().astype(theano.config.floatX)) inp.zeros_like().astype(theano.config.floatX))
else: else:
retval.append(grad_undefined(self, ii, inp)) retval.append(grad_undefined(self, ii, inp))
else: else:
...@@ -1342,9 +1346,10 @@ class Mul(ScalarOp): ...@@ -1342,9 +1346,10 @@ class Mul(ScalarOp):
output_type = self.output_types([i.type for i in inputs])[0] output_type = self.output_types([i.type for i in inputs])[0]
if output_type in complex_types: if output_type in complex_types:
if not gz.type in complex_types: if not gz.type in complex_types:
raise TypeError('Mul with output_type ' + str(output_type) +\ raise TypeError(
' expected gz type to be complex, got gz with type ' +\ 'Mul with output_type ' + str(output_type) +
str(gz.type)) ' expected gz type to be complex, got gz with type ' +
str(gz.type))
if output_type in discrete_types: if output_type in discrete_types:
return [ipt.zeros_like().astype(theano.config.floatX) return [ipt.zeros_like().astype(theano.config.floatX)
...@@ -1364,7 +1369,7 @@ class Mul(ScalarOp): ...@@ -1364,7 +1369,7 @@ class Mul(ScalarOp):
retval += [yr * real(gz) + yi * imag(gz)] retval += [yr * real(gz) + yi * imag(gz)]
else: else:
retval += [mul(*([gz] + utils.difference(inputs, retval += [mul(*([gz] + utils.difference(inputs,
[input])))] [input])))]
return retval return retval
...@@ -1420,10 +1425,10 @@ def int_or_true_div(x_discrete, y_discrete): ...@@ -1420,10 +1425,10 @@ def int_or_true_div(x_discrete, y_discrete):
"`x.__truediv__(y)`.") "`x.__truediv__(y)`.")
elif config.int_division == 'int': elif config.int_division == 'int':
warnings.warn( warnings.warn(
"Division of two integer types with x / y is deprecated, " "Division of two integer types with x / y is deprecated, "
"please use x // y for an integer division.", "please use x // y for an integer division.",
DeprecationWarning, DeprecationWarning,
stacklevel=4) stacklevel=4)
return 'int' return 'int'
elif config.int_division == 'floatX': elif config.int_division == 'floatX':
return 'true' return 'true'
...@@ -1493,8 +1498,8 @@ true_div = TrueDiv(upcast_out, name='true_div') ...@@ -1493,8 +1498,8 @@ true_div = TrueDiv(upcast_out, name='true_div')
class IntDiv(BinaryScalarOp): class IntDiv(BinaryScalarOp):
complex_error = ComplexError( complex_error = ComplexError(
"Theano does not support integer division (//) on " "Theano does not support integer division (//) on "
"complex numbers, since numpy deprecated it.") "complex numbers, since numpy deprecated it.")
def impl(self, x, y): def impl(self, x, y):
return x // y return x // y
...@@ -1575,8 +1580,8 @@ def mod_check(x, y): ...@@ -1575,8 +1580,8 @@ def mod_check(x, y):
class Mod(BinaryScalarOp): class Mod(BinaryScalarOp):
complex_error = ComplexError( complex_error = ComplexError(
"Theano does not support the mod operator (%) on " "Theano does not support the mod operator (%) on "
"complex numbers, since numpy deprecated it.") "complex numbers, since numpy deprecated it.")
def impl(self, x, y): def impl(self, x, y):
if isinstance(x, numpy.complex) or isinstance(y, numpy.complex): if isinstance(x, numpy.complex) or isinstance(y, numpy.complex):
...@@ -1874,7 +1879,7 @@ class Abs(UnaryScalarOp): ...@@ -1874,7 +1879,7 @@ class Abs(UnaryScalarOp):
outputs = [float64()] outputs = [float64()]
else: else:
outputs = [t() for t in self.output_types( outputs = [t() for t in self.output_types(
[input.type for input in inputs])] [input.type for input in inputs])]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def impl(self, x): def impl(self, x):
...@@ -1989,7 +1994,7 @@ class RoundHalfToEven(UnaryScalarOp): ...@@ -1989,7 +1994,7 @@ class RoundHalfToEven(UnaryScalarOp):
def c_code___(self, node, name, (x, ), (z, ), sub): def c_code___(self, node, name, (x, ), (z, ), sub):
typ = node.outputs[0].type.dtype typ = node.outputs[0].type.dtype
if not node.outputs[0].type.dtype in ['float32', 'float64']: if not typ in ['float32', 'float64']:
Exception("The output should be float32 or float64") Exception("The output should be float32 or float64")
return dedent(""" return dedent("""
...@@ -2036,7 +2041,7 @@ class RoundHalfToEven(UnaryScalarOp): ...@@ -2036,7 +2041,7 @@ class RoundHalfToEven(UnaryScalarOp):
#undef ROUNDING_EPSILON #undef ROUNDING_EPSILON
""") """ % locals())
round_half_to_even = RoundHalfToEven(same_out_float_only) round_half_to_even = RoundHalfToEven(same_out_float_only)
...@@ -2750,15 +2755,14 @@ class Composite(ScalarOp): ...@@ -2750,15 +2755,14 @@ class Composite(ScalarOp):
subd[orphan] = orphan.type.c_literal(orphan.data) subd[orphan] = orphan.type.c_literal(orphan.data)
else: else:
raise ValueError( raise ValueError(
"All orphans in the fgraph to Composite must" "All orphans in the fgraph to Composite must"
" be Constant instances.") " be Constant instances.")
_c_code = "{\n" _c_code = "{\n"
i = 0
j = 0
self.nodenames = ["%(nodename)s_" + ('subnode%i' % j) self.nodenames = ["%(nodename)s_" + ('subnode%i' % j)
for j, n in enumerate(self.fgraph.toposort())] for j, n in enumerate(self.fgraph.toposort())]
i = 0
for j, node in enumerate(self.fgraph.toposort()): for j, node in enumerate(self.fgraph.toposort()):
for output in node.outputs: for output in node.outputs:
if output not in subd: if output not in subd:
...@@ -2835,6 +2839,10 @@ class Composite(ScalarOp): ...@@ -2835,6 +2839,10 @@ class Composite(ScalarOp):
self.fgraph = fgraph self.fgraph = fgraph
def __init__(self, inputs, outputs): def __init__(self, inputs, outputs):
# We need to clone the graph as sometimes its nodes already
# contain a reference to an fgraph. As we want the Composite
# to be pickable, we can't have reference to fgraph.
inputs, outputs = gof.graph.clone(inputs, outputs)
self.inputs = copy(inputs) self.inputs = copy(inputs)
self.outputs = copy(outputs) self.outputs = copy(outputs)
self.inputs_type = tuple([input.type for input in inputs]) self.inputs_type = tuple([input.type for input in inputs])
......
...@@ -12,10 +12,14 @@ If you do want to rewrite these tests, bear in mind: ...@@ -12,10 +12,14 @@ If you do want to rewrite these tests, bear in mind:
import unittest import unittest
import theano import theano
from theano.gof import Variable, Op, FunctionGraph from theano.gof import FunctionGraph
from theano import gof from theano import gof
from theano.scalar.basic import * from theano.scalar.basic import (floats, float32, float64,
ints, int8, int32, complex64,
ComplexError, IntDiv, TrueDiv,
Composite, add, div_proxy,
and_, eq, neq, invert, mul)
def inputs(): def inputs():
...@@ -216,7 +220,7 @@ class test_div(unittest.TestCase): ...@@ -216,7 +220,7 @@ class test_div(unittest.TestCase):
d = float64() d = float64()
f = float32() f = float32()
print (a//b).owner.op #print (a//b).owner.op
assert isinstance((a//b).owner.op, IntDiv) assert isinstance((a//b).owner.op, IntDiv)
assert isinstance((b//a).owner.op, IntDiv) assert isinstance((b//a).owner.op, IntDiv)
assert isinstance((b/d).owner.op, TrueDiv) assert isinstance((b/d).owner.op, TrueDiv)
......
...@@ -880,6 +880,56 @@ class T_using_gpu(unittest.TestCase): ...@@ -880,6 +880,56 @@ class T_using_gpu(unittest.TestCase):
for x in f.maker.fgraph.toposort()]) for x in f.maker.fgraph.toposort()])
# Used in T_fibby
class Fibby(theano.Op):
"""
An arbitrarily generalized Fibbonacci sequence
"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x):
x_ = theano.tensor.as_tensor_variable(x)
assert x_.ndim == 1
return theano.Apply(self,
inputs=[x_],
outputs=[x_.type()])
# using x_.type() is dangerous, it copies x's broadcasting
# behaviour
def perform(self, node, inputs, output_storage):
x, = inputs
y = output_storage[0][0] = x.copy()
for i in range(2, len(x)):
y[i] = y[i - 1] * y[i - 2] + x[i]
def c_code(self, node, name, inames, onames, sub):
x, = inames
y, = onames
fail = sub['fail']
return """
Py_XDECREF(%(y)s);
%(y)s = (PyArrayObject*)PyArray_FromArray(
%(x)s, 0, NPY_ARRAY_ENSURECOPY);
if (!%(y)s)
%(fail)s;
{//New scope needed to make compilation work
dtype_%(y)s * y = (dtype_%(y)s*)%(y)s->data;
dtype_%(x)s * x = (dtype_%(x)s*)%(x)s->data;
for (int i = 2; i < %(x)s->dimensions[0]; ++i)
y[i] = y[i-1]*y[i-2] + x[i];
}
""" % locals()
def c_code_cache_version(self):
return (1,)
class T_fibby(unittest.TestCase): class T_fibby(unittest.TestCase):
## All tests here belong to ## All tests here belong to
## http://deeplearning.net/software/theano/extending/fibby.html ## http://deeplearning.net/software/theano/extending/fibby.html
...@@ -888,54 +938,8 @@ class T_fibby(unittest.TestCase): ...@@ -888,54 +938,8 @@ class T_fibby(unittest.TestCase):
def test_fibby_1(self): def test_fibby_1(self):
class Fibby(theano.Op): # The definition of class Fibby is done outside of the test,
# so the object can be pickled.
"""
An arbitrarily generalized Fibbonacci sequence
"""
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x):
x_ = theano.tensor.as_tensor_variable(x)
assert x_.ndim == 1
return theano.Apply(self,
inputs=[x_],
outputs=[x_.type()])
# using x_.type() is dangerous, it copies x's broadcasting
# behaviour
def perform(self, node, inputs, output_storage):
x, = inputs
y = output_storage[0][0] = x.copy()
for i in range(2, len(x)):
y[i] = y[i - 1] * y[i - 2] + x[i]
def c_code(self, node, name, inames, onames, sub):
x, = inames
y, = onames
fail = sub['fail']
return """
Py_XDECREF(%(y)s);
%(y)s = (PyArrayObject*)PyArray_FromArray(
%(x)s, 0, NPY_ARRAY_ENSURECOPY);
if (!%(y)s)
%(fail)s;
{//New scope needed to make compilation work
dtype_%(y)s * y = (dtype_%(y)s*)%(y)s->data;
dtype_%(x)s * x = (dtype_%(x)s*)%(x)s->data;
for (int i = 2; i < %(x)s->dimensions[0]; ++i)
y[i] = y[i-1]*y[i-2] + x[i];
}
""" % locals()
def c_code_cache_version(self):
return (1,)
fibby = Fibby() fibby = Fibby()
from theano.tensor.opt import (get_scalar_constant_value, from theano.tensor.opt import (get_scalar_constant_value,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论