提交 7fc8a0b2 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a test for jax conversion of theano.scalar.basic.Identity

上级 e7e92dce
...@@ -538,3 +538,12 @@ def test_arange(): ...@@ -538,3 +538,12 @@ def test_arange():
out = tt.arange(a) out = tt.arange(a)
fgraph = theano.gof.FunctionGraph([a], [out]) fgraph = theano.gof.FunctionGraph([a], [out])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_identity():
a = tt.scalar("a")
a.tag.test_value = 10
out = theano.scalar.basic.identity(a)
fgraph = theano.gof.FunctionGraph([a], [out])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论