提交 a62dec23 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Add JAX conversions for dot and arange

上级 02c02d72
...@@ -500,3 +500,35 @@ def test_nnet(): ...@@ -500,3 +500,35 @@ def test_nnet():
    out = tt.nnet.softplus(x)
    fgraph = theano.gof.FunctionGraph([x], [out])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_tensor_basics():
    """Check that basic tensor expressions (dot products, scaling) round-trip
    through the JAX conversion and agree with the Python backend.
    """
    floatX = theano.config.floatX

    y = tt.vector("y")
    y.tag.test_value = np.r_[1.0, 2.0].astype(floatX)
    x = tt.vector("x")
    x.tag.test_value = np.r_[3.0, 4.0].astype(floatX)
    A = tt.matrix("A")
    # Use a deterministic test value instead of `np.empty`: uninitialized
    # memory can contain NaNs, and NaN != NaN would make the JAX-vs-Python
    # output comparison flaky.
    A.tag.test_value = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=floatX)
    alpha = tt.scalar("alpha")
    alpha.tag.test_value = np.array(3.0, dtype=floatX)
    beta = tt.scalar("beta")
    beta.tag.test_value = np.array(5.0, dtype=floatX)

    # This should be converted into a `Gemv` `Op` when the non-JAX compatible
    # optimizations are turned on; however, when using JAX mode, it should
    # leave the expression alone.
    out = y.dot(alpha * A).dot(x) + beta * y
    fgraph = theano.gof.FunctionGraph([y, x, A, alpha, beta], [out])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_arange():
    """`ARange` conversion is expected to fail: JAX cannot build an
    `arange` from a traced (symbolic) bound.
    """
    n = tt.scalar("a")
    n.tag.test_value = 10

    range_out = tt.arange(n)

    fgraph = theano.gof.FunctionGraph([n], [range_out])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
...@@ -26,6 +26,8 @@ from theano.tensor.subtensor import ( ...@@ -26,6 +26,8 @@ from theano.tensor.subtensor import (
from theano.scan_module.scan_op import Scan
from theano.scan_module.scan_utils import scan_args as ScanArgs
from theano.tensor.basic import (
    Dot,
    ARange,
    TensorFromScalar,
    ScalarFromTensor,
    AllocEmpty,
...@@ -198,6 +200,23 @@ def jax_funcify_Alloc(op): ...@@ -198,6 +200,23 @@ def jax_funcify_Alloc(op):
    return alloc
@jax_funcify.register(Dot)
def jax_funcify_Dot(op):
    """Return the JAX implementation for the `Dot` `Op`.

    `jnp.dot` already has the required `(x, y)` calling convention, so it
    is returned directly rather than wrapped in a trivial closure.
    """
    return jnp.dot
@jax_funcify.register(ARange)
def jax_funcify_ARange(op):
    """Return the JAX implementation for the `ARange` `Op`."""
    # Capture the dtype once instead of re-reading it from `op` per call.
    dtype = op.dtype

    # XXX: `jnp.arange` currently requires concrete (non-traced) arguments.
    def arange(start, stop, step):
        return jnp.arange(start, stop, step, dtype=dtype)

    return arange
def jnp_safe_copy(x):
    try:
        res = jnp.copy(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论