提交 d734be74 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1486 from mrocklin/sympy-ccode-op

SymPyCCode op
import theano import theano
from theano.gof.utils import give_variables_names, unique from theano.gof.utils import give_variables_names, unique, remove
from theano.gof.python25 import all from theano.gof.python25 import all
...@@ -34,3 +34,8 @@ def test_give_variables_names_small(): ...@@ -34,3 +34,8 @@ def test_give_variables_names_small():
give_variables_names(fgraph.variables) give_variables_names(fgraph.variables)
assert all(var.name for var in fgraph.variables) assert all(var.name for var in fgraph.variables)
assert unique([var.name for var in fgraph.variables]) assert unique([var.name for var in fgraph.variables])
def test_remove():
even = lambda x: x % 2 == 0
odd = lambda x: x % 2 == 1
assert remove(even, range(5)) == filter(odd, range(5))
...@@ -410,3 +410,15 @@ def give_variables_names(variables): ...@@ -410,3 +410,15 @@ def give_variables_names(variables):
"Maybe you've named some of the variables identically") "Maybe you've named some of the variables identically")
return variables return variables
def remove(predicate, coll):
""" Return those items of collection for which predicate(item) is true.
>>> from itertoolz import remove
>>> def even(x):
... return x % 2 == 0
>>> remove(even, [1, 2, 3, 4])
[1, 3]
"""
return filter(lambda x: not predicate(x), coll)
import numpy as np
from theano.scalar.basic import Apply, ScalarOp, as_scalar, float64, float32, int64
from theano.gof.utils import remove
imported_sympy = False
try:
import sympy
from sympy.utilities.codegen import get_default_datatype, codegen
imported_sympy = True
except ImportError:
pass
import itertools as it
names = ("sympy_func_%d"%i for i in it.count(0))
def include_line(line):
return '#include' in line
def sympy_dtype(expr):
return get_default_datatype(expr).cname
def theano_dtype(expr):
return {'double': float64,
'float': float32,
'int': int64}[sympy_dtype(expr)]
class SymPyCCode(ScalarOp):
""" An Operator that wraps SymPy's C code generation
>>> from sympy.abc import x, y # SymPy Variables
>>> from theano.scalar.basic_sympy import SymPyCCode
>>> op = SymPyCCode([x, y], x + y)
>>> from theano.scalar.basic import floats
>>> xt, yt = floats('xy') # Theano variables
>>> zt = op(xt, yt)
>>> import theano
>>> f = theano.function([xt, yt], zt)
>>> f(1.0, 2.0)
3.0
"""
def __init__(self, inputs, expr, name=None):
self.name = name or next(names)
self.inputs = inputs
self.expr = expr
def _sympy_c_code(self):
[(c_name, c_code), (h_name, c_header)] = codegen(
(self.name, self.expr), 'C', 'project_name',
header=False, argument_sequence=self.inputs)
return c_code
def c_support_code(self):
c_code = self._sympy_c_code()
return '\n'.join(remove(include_line, c_code.split('\n')))
def c_headers(self):
c_code = self._sympy_c_code()
return [line.replace("#include", "").strip() for line in
c_code.split('\n') if include_line(line)
and not 'project_name' in line]
def c_code(self, node, name, input_names, output_names, sub):
y, = output_names
xs = ', '.join(input_names)
f = self.name
return "%(y)s = %(f)s(%(xs)s);" % locals()
def output_types_preference(self, *inputs):
return [theano_dtype(self.expr)]
def make_node(self, *inputs):
# TODO: assert input types are correct use get_default_datatype
if len(inputs) != len(self.inputs):
raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" % (self, len(inputs), str(inputs), self.nin))
inputs = [as_scalar(input) for input in inputs]
outputs = [t() for t in self.output_types([input.type for input in inputs])]
return Apply(self, inputs, outputs)
def perform(self, node, inputs, output_storage):
raise NotImplementedError()
def grad(self, inputs, output_grads):
return [SymPyCCode(self.inputs,
self.expr.diff(inp),
name=self.name+"_grad_%d"%i)(*inputs)
for i, inp in enumerate(self.inputs)]
def _info(self):
return type(self), self.name, tuple(self.inputs), self.expr
def __eq__(self, other):
return type(self) == type(other) and self._info() == other._info()
def __hash__(self):
return hash(self._info())
from theano.scalar.basic_sympy import SymPyCCode
from theano.scalar.basic import floats
import theano
try:
import sympy
xs = sympy.Symbol('x')
ys = sympy.Symbol('y')
except ImportError:
from nose.plugins.skip import SkipTest
raise SkipTest('optional package sympy disabled')
xt, yt = floats('xy')
def test_SymPyCCode():
op = SymPyCCode([xs, ys], xs + ys)
e = op(xt, yt)
g = theano.gof.FunctionGraph([xt, yt], [e])
fn = theano.gof.CLinker().accept(g).make_function()
assert fn(1.0, 2.0) == 3.0
def test_grad():
op = SymPyCCode([xs], xs**2)
zt = op(xt)
ztprime = theano.grad(zt, xt)
assert ztprime.owner.op.expr == 2*xs
def test_multivar_grad():
op = SymPyCCode([xs, ys], xs**2 + ys**3)
zt = op(xt, yt)
dzdx, dzdy = theano.grad(zt, [xt, yt])
assert dzdx.owner.op.expr == 2*xs
assert dzdy.owner.op.expr == 3*ys**2
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论