Commit 75485c13 authored by Ricardo Vieira, committed by Ricardo Vieira

Simplify BatchedDot implementation

The Op now always expects rank 3 inputs, and any dimshuffles are added explicitly by the helper function
Parent 18f245fa
...@@ -99,9 +99,7 @@ def jax_funcify_BatchedDot(op, **kwargs): ...@@ -99,9 +99,7 @@ def jax_funcify_BatchedDot(op, **kwargs):
def batched_dot(a, b): def batched_dot(a, b):
if a.shape[0] != b.shape[0]: if a.shape[0] != b.shape[0]:
raise TypeError("Shapes must match in the 0-th dimension") raise TypeError("Shapes must match in the 0-th dimension")
if a.ndim == 2 or b.ndim == 2: return jnp.matmul(a, b)
return jnp.einsum("n...j,nj...->n...", a, b)
return jnp.einsum("nij,njk->nik", a, b)
return batched_dot return batched_dot
......
...@@ -895,6 +895,8 @@ def numba_funcify_BatchedDot(op, node, **kwargs): ...@@ -895,6 +895,8 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
@numba_njit @numba_njit
def batched_dot(x, y): def batched_dot(x, y):
# Numba does not support 3D matmul
# https://github.com/numba/numba/issues/3804
shape = x.shape[:-1] + y.shape[2:] shape = x.shape[:-1] + y.shape[2:]
z0 = np.empty(shape, dtype=dtype) z0 = np.empty(shape, dtype=dtype)
for i in range(z0.shape[0]): for i in range(z0.shape[0]):
......
This diff is collapsed.
...@@ -43,15 +43,6 @@ def test_jax_BatchedDot(): ...@@ -43,15 +43,6 @@ def test_jax_BatchedDot():
with pytest.raises(TypeError): with pytest.raises(TypeError):
pytensor_jax_fn(*inputs) pytensor_jax_fn(*inputs)
# matrix . matrix
a = matrix("a")
a.tag.test_value = np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3))
b = matrix("b")
b.tag.test_value = np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3))
out = at_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_basic_multiout(): def test_jax_basic_multiout():
rng = np.random.default_rng(213234) rng = np.random.default_rng(213234)
......
...@@ -843,23 +843,23 @@ def test_Softplus(x, exc): ...@@ -843,23 +843,23 @@ def test_Softplus(x, exc):
[ [
( (
set_test_value( set_test_value(
at.dmatrix(), at.dtensor3(),
rng.random(size=(3, 3)).astype("float64"), rng.random(size=(2, 3, 3)).astype("float64"),
), ),
set_test_value( set_test_value(
at.dmatrix(), at.dtensor3(),
rng.random(size=(3, 3)).astype("float64"), rng.random(size=(2, 3, 3)).astype("float64"),
), ),
None, None,
), ),
( (
set_test_value( set_test_value(
at.dmatrix(), at.dtensor3(),
rng.random(size=(3, 3)).astype("float64"), rng.random(size=(2, 3, 3)).astype("float64"),
), ),
set_test_value( set_test_value(
at.lmatrix(), at.ltensor3(),
rng.poisson(size=(3, 3)).astype("int64"), rng.poisson(size=(2, 3, 3)).astype("int64"),
), ),
None, None,
), ),
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to comment