提交 76bee524 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

add SymPyCCode.grad

上级 79fc80bf
......@@ -86,6 +86,12 @@ class SymPyCCode(ScalarOp):
def perform(self, node, inputs, output_storage):
raise NotImplementedError()
def grad(self, inputs, output_grads):
return [SymPyCCode(self.inputs,
self.expr.diff(inp),
name=self.name+"_prime_%d"%i)(*inputs)
for i, inp in enumerate(self.inputs)]
def _info(self):
return type(self), self.name, tuple(self.inputs), self.expr
......
from theano.scalar.basic_sympy import SymPyCCode
from theano.scalar.basic import floats
from theano.gof import FunctionGraph
from theano import gof
import theano
import sympy
xs = sympy.Symbol('x')
ys = sympy.Symbol('y')
xt, yt = floats('xy')
def test_SymPyCCode():
xs = sympy.Symbol('x')
ys = sympy.Symbol('y')
op = SymPyCCode([xs, ys], xs + ys)
xt, yt = floats('xy')
e = op(xt, yt)
g = FunctionGraph([xt, yt], [e])
fn = gof.CLinker().accept(g).make_function()
g = theano.gof.FunctionGraph([xt, yt], [e])
fn = theano.gof.CLinker().accept(g).make_function()
assert fn(1.0, 2.0) == 3.0
def test_grad():
op = SymPyCCode([xs], xs**2)
zt = op(xt)
ztprime = theano.grad(zt, xt)
assert ztprime.owner.op.expr == 2*xs
def test_multivar_grad():
op = SymPyCCode([xs, ys], xs**2 + ys**2)
zt = op(xt, yt)
dzdx, dzdy = theano.grad(zt, [xt, yt])
assert dzdx.owner.op.expr == 2*xs
assert dzdy.owner.op.expr == 2*ys
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论