提交 d38552d9 authored 作者: danhphan's avatar danhphan 提交者: Brandon T. Willard

Add JAX implementation for BroadcastTo

上级 f7a506ff
......@@ -37,6 +37,7 @@ from aesara.tensor.blas import BatchedDot
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.extra_ops import (
Bartlett,
BroadcastTo,
CumOp,
DiffOp,
FillDiagonal,
......@@ -1157,3 +1158,11 @@ def jax_funcify_Psi(op, node, **kwargs):
return jax.scipy.special.digamma(x)
return psi
@jax_funcify.register(BroadcastTo)
def jax_funcify_BroadcastTo(op, **kwargs):
def broadcast_to(x, *shape):
return jnp.broadcast_to(x, shape)
return broadcast_to
......@@ -1214,6 +1214,34 @@ def test_extra_ops():
)
def set_test_value(x, v):
x.tag.test_value = v
return x
@pytest.mark.parametrize(
"x, shape",
[
(
set_test_value(
vector("x"), np.random.random(size=(2,)).astype(config.floatX)
),
[at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
),
(
set_test_value(
vector("x"), np.random.random(size=(2,)).astype(config.floatX)
),
[at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
),
],
)
def test_BroadcastTo(x, shape):
out = at_extra_ops.broadcast_to(x, shape)
fgraph = FunctionGraph(outputs=[out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论