Unverified 提交 abde1a16 authored 作者: Dan F-M's avatar Dan F-M

adding JAX conversion for BatchedDot

BatchedDot only supports tensor3; fixing broadcasting behavior in BatchedDot; adding test for TypeError on dimension mismatch; removing extra tests.
上级 b0b34b59
......@@ -816,6 +816,42 @@ def test_second():
compare_jax_and_py(fgraph, [np.zeros([5], dtype=theano.config.floatX), 5.0])
def test_jax_BatchedDot():
    """Check that the JAX conversion of ``BatchedDot`` matches Theano's output."""

    def ramp(start, stop, shape):
        # Evenly spaced test data in the configured float dtype, reshaped
        # to the requested symbolic-variable shape.
        size = int(np.prod(shape))
        return np.linspace(start, stop, size).astype(theano.config.floatX).reshape(shape)

    # tensor3 . tensor3
    a = tt.tensor3("a")
    a.tag.test_value = ramp(-1, 1, (10, 5, 3))
    b = tt.tensor3("b")
    b.tag.test_value = ramp(1, -1, (10, 3, 2))
    out = tt.blas.BatchedDot()(a, b)
    fgraph = theano.gof.FunctionGraph([a, b], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # A dimension mismatch should raise a TypeError for compatibility
    inputs = [get_test_value(a)[:-1], get_test_value(b)]
    opts = theano.gof.Query(include=[None], exclude=["cxx_only", "BlasOpt"])
    jax_mode = theano.compile.mode.Mode(theano.sandbox.jax_linker.JAXLinker(), opts)
    theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
    with pytest.raises(TypeError):
        theano_jax_fn(*inputs)

    # matrix . matrix
    a = tt.matrix("a")
    a.tag.test_value = ramp(-1, 1, (5, 3))
    b = tt.matrix("b")
    b.tag.test_value = ramp(1, -1, (5, 3))
    out = tt.blas.BatchedDot()(a, b)
    fgraph = theano.gof.FunctionGraph([a, b], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_shared():
a = theano.shared(np.array([1, 2, 3], dtype=theano.config.floatX))
......
......@@ -34,6 +34,7 @@ from theano.tensor.basic import (
ScalarFromTensor,
TensorFromScalar,
)
from theano.tensor.blas import BatchedDot
from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from theano.tensor.extra_ops import (
Bartlett,
......@@ -1066,3 +1067,15 @@ def jax_funcify_Eye(op):
return jnp.eye(N, M, k, dtype=dtype)
return eye
@jax_funcify.register(BatchedDot)
def jax_funcify_BatchedDot(op):
    """Build the JAX implementation of Theano's ``BatchedDot`` Op."""

    def batched_dot(a, b):
        # Theano's BatchedDot raises a TypeError on mismatched batch sizes,
        # so mirror that for compatibility.
        if a.shape[0] != b.shape[0]:
            raise TypeError("Shapes must match in the 0-th dimension")
        if 2 in (a.ndim, b.ndim):
            # One operand is a matrix: let the ellipsis broadcast the
            # remaining free axes around the contracted index ``j``.
            return jnp.einsum("n...j,nj...->n...", a, b)
        # tensor3 . tensor3: ordinary batched matrix multiplication.
        return jnp.einsum("nij,njk->nik", a, b)

    return batched_dot
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论