提交 21dac033 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

small edits to SympyCCode and test

Change grad function name to "grad", not "prime" Change y**2 test to y**3
上级 649a41db
......@@ -88,7 +88,7 @@ class SymPyCCode(ScalarOp):
def grad(self, inputs, output_grads):
return [SymPyCCode(self.inputs,
self.expr.diff(inp),
name=self.name+"_prime_%d"%i)(*inputs)
name=self.name+"_grad_%d"%i)(*inputs)
for i, inp in enumerate(self.inputs)]
def _info(self):
......
......@@ -26,8 +26,8 @@ def test_grad():
assert ztprime.owner.op.expr == 2*xs
def test_multivar_grad():
op = SymPyCCode([xs, ys], xs**2 + ys**2)
op = SymPyCCode([xs, ys], xs**2 + ys**3)
zt = op(xt, yt)
dzdx, dzdy = theano.grad(zt, [xt, yt])
assert dzdx.owner.op.expr == 2*xs
assert dzdy.owner.op.expr == 2*ys
assert dzdy.owner.op.expr == 3*ys**2
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论