提交 fa5a65c0 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

protect sympy tests if no sympy

上级 76bee524
......@@ -12,7 +12,6 @@ except ImportError:
pass
import itertools as it
names = ("sympy_func_%d"%i for i in it.count(0))
def include_line(line):
......
from theano.scalar.basic_sympy import SymPyCCode
from theano.scalar.basic import floats
import theano
import sympy
xs = sympy.Symbol('x')
ys = sympy.Symbol('y')
try:
import sympy
xs = sympy.Symbol('x')
ys = sympy.Symbol('y')
except ImportError:
sympy = False
xt, yt = floats('xy')
def test_SymPyCCode():
if not sympy: return
op = SymPyCCode([xs, ys], xs + ys)
e = op(xt, yt)
g = theano.gof.FunctionGraph([xt, yt], [e])
......@@ -15,12 +20,14 @@ def test_SymPyCCode():
assert fn(1.0, 2.0) == 3.0
def test_grad():
if not sympy: return
op = SymPyCCode([xs], xs**2)
zt = op(xt)
ztprime = theano.grad(zt, xt)
assert ztprime.owner.op.expr == 2*xs
def test_multivar_grad():
if not sympy: return
op = SymPyCCode([xs, ys], xs**2 + ys**2)
zt = op(xt, yt)
dzdx, dzdy = theano.grad(zt, [xt, yt])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论