Unverified 提交 454ae317 authored 作者: Junpeng Lao's avatar Junpeng Lao 提交者: GitHub

Implement a Jax conversion for the Second Op (#185)

上级 30be6347
...@@ -696,6 +696,20 @@ def test_identity(): ...@@ -696,6 +696,20 @@ def test_identity():
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_second():
a0 = tt.scalar("a0")
b = tt.scalar("b")
out = theano.scalar.basic.second(a0, b)
fgraph = theano.gof.FunctionGraph([a0, b], [out])
compare_jax_and_py(fgraph, [10.0, 5.0])
a1 = tt.vector("a1")
out = tt.second(a1, b)
fgraph = theano.gof.FunctionGraph([a1, b], [out])
compare_jax_and_py(fgraph, [np.zeros([5], dtype=theano.config.floatX), 5.0])
def test_shared(): def test_shared():
a = theano.shared(np.array([1, 2, 3], dtype=theano.config.floatX)) a = theano.shared(np.array([1, 2, 3], dtype=theano.config.floatX))
......
...@@ -19,7 +19,7 @@ from theano.compile.ops import ( ...@@ -19,7 +19,7 @@ from theano.compile.ops import (
) )
from theano.gof import FunctionGraph from theano.gof import FunctionGraph
from theano.ifelse import IfElse from theano.ifelse import IfElse
from theano.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp from theano.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from theano.scan.op import Scan from theano.scan.op import Scan
from theano.scan.utils import scan_args as ScanArgs from theano.scan.utils import scan_args as ScanArgs
from theano.tensor.basic import ( from theano.tensor.basic import (
...@@ -256,6 +256,14 @@ def jax_funcify_ScalarSoftplus(op): ...@@ -256,6 +256,14 @@ def jax_funcify_ScalarSoftplus(op):
return scalarsoftplus return scalarsoftplus
@jax_funcify.register(Second)
def jax_funcify_Second(op):
def second(x, y):
return jnp.broadcast_to(y, x.shape)
return second
@jax_funcify.register(AllocEmpty) @jax_funcify.register(AllocEmpty)
def jax_funcify_AllocEmpty(op): def jax_funcify_AllocEmpty(op):
def allocempty(*shape): def allocempty(*shape):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论