提交 75485c13 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify BatchedDot implementation

The Op now always expects rank 3 inputs, and any dimshuffles are added explicitly by the helper function
上级 18f245fa
......@@ -99,9 +99,7 @@ def jax_funcify_BatchedDot(op, **kwargs):
def batched_dot(a, b):
if a.shape[0] != b.shape[0]:
raise TypeError("Shapes must match in the 0-th dimension")
if a.ndim == 2 or b.ndim == 2:
return jnp.einsum("n...j,nj...->n...", a, b)
return jnp.einsum("nij,njk->nik", a, b)
return jnp.matmul(a, b)
return batched_dot
......
......@@ -895,6 +895,8 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
@numba_njit
def batched_dot(x, y):
# Numba does not support 3D matmul
# https://github.com/numba/numba/issues/3804
shape = x.shape[:-1] + y.shape[2:]
z0 = np.empty(shape, dtype=dtype)
for i in range(z0.shape[0]):
......
差异被折叠。
......@@ -43,15 +43,6 @@ def test_jax_BatchedDot():
with pytest.raises(TypeError):
pytensor_jax_fn(*inputs)
# matrix . matrix
a = matrix("a")
a.tag.test_value = np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3))
b = matrix("b")
b.tag.test_value = np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3))
out = at_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_basic_multiout():
rng = np.random.default_rng(213234)
......
......@@ -843,23 +843,23 @@ def test_Softplus(x, exc):
[
(
set_test_value(
at.dmatrix(),
rng.random(size=(3, 3)).astype("float64"),
at.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"),
),
set_test_value(
at.dmatrix(),
rng.random(size=(3, 3)).astype("float64"),
at.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"),
),
None,
),
(
set_test_value(
at.dmatrix(),
rng.random(size=(3, 3)).astype("float64"),
at.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"),
),
set_test_value(
at.lmatrix(),
rng.poisson(size=(3, 3)).astype("int64"),
at.ltensor3(),
rng.poisson(size=(2, 3, 3)).astype("int64"),
),
None,
),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论