提交 cfc931fa authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow JAX Reshape to work with constant shape inputs

上级 6c6bf08a
......@@ -11,6 +11,7 @@ from numpy.random.bit_generator import _coerce_to_uint32_array
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.configdefaults import config
from aesara.graph import Constant
from aesara.graph.fg import FunctionGraph
from aesara.ifelse import IfElse
from aesara.link.utils import fgraph_to_python
......@@ -728,9 +729,20 @@ def jax_funcify_MakeVector(op, **kwargs):
@jax_funcify.register(Reshape)
def jax_funcify_Reshape(op, **kwargs):
def reshape(x, shape):
return jnp.reshape(x, shape)
def jax_funcify_Reshape(op, node, **kwargs):
# JAX reshape only works with constant inputs, otherwise JIT fails
shape = node.inputs[1]
if isinstance(shape, Constant):
constant_shape = shape.data
def reshape(x, _):
return jax.numpy.reshape(x, constant_shape)
else:
def reshape(x, shape):
return jax.numpy.reshape(x, shape)
return reshape
......
......@@ -863,10 +863,6 @@ def test_jax_MakeVector():
compare_jax_and_py(x_fg, [])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_Reshape():
a = vector("a")
x = reshape(a, (2, 2))
......@@ -877,16 +873,20 @@ def test_jax_Reshape():
# See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68
x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
with pytest.raises(
TypeError,
match="Shapes must be 1D sequences of concrete values of integer type",
):
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_jax_Reshape_nonconcrete():
a = vector("a")
b = iscalar("b")
x = reshape(a, (b, b))
x_fg = FunctionGraph([a, b], [x])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
with pytest.raises(
TypeError,
match="Shapes must be 1D sequences of concrete values of integer type",
):
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
def test_jax_Dimshuffle():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论