提交 cfc931fa authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow JAX Reshape to work with constant shape inputs

上级 6c6bf08a
...@@ -11,6 +11,7 @@ from numpy.random.bit_generator import _coerce_to_uint32_array ...@@ -11,6 +11,7 @@ from numpy.random.bit_generator import _coerce_to_uint32_array
from aesara.compile.ops import DeepCopyOp, ViewOp from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.ifelse import IfElse from aesara.ifelse import IfElse
from aesara.link.utils import fgraph_to_python from aesara.link.utils import fgraph_to_python
...@@ -728,9 +729,20 @@ def jax_funcify_MakeVector(op, **kwargs): ...@@ -728,9 +729,20 @@ def jax_funcify_MakeVector(op, **kwargs):
@jax_funcify.register(Reshape) @jax_funcify.register(Reshape)
def jax_funcify_Reshape(op, **kwargs): def jax_funcify_Reshape(op, node, **kwargs):
def reshape(x, shape):
return jnp.reshape(x, shape) # JAX reshape only works with constant inputs, otherwise JIT fails
shape = node.inputs[1]
if isinstance(shape, Constant):
constant_shape = shape.data
def reshape(x, _):
return jax.numpy.reshape(x, constant_shape)
else:
def reshape(x, shape):
return jax.numpy.reshape(x, shape)
return reshape return reshape
......
...@@ -863,10 +863,6 @@ def test_jax_MakeVector(): ...@@ -863,10 +863,6 @@ def test_jax_MakeVector():
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_Reshape(): def test_jax_Reshape():
a = vector("a") a = vector("a")
x = reshape(a, (2, 2)) x = reshape(a, (2, 2))
...@@ -877,16 +873,20 @@ def test_jax_Reshape(): ...@@ -877,16 +873,20 @@ def test_jax_Reshape():
# See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68 # See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68
x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2)) x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
x_fg = FunctionGraph([a], [x]) x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) with pytest.raises(
TypeError,
match="Shapes must be 1D sequences of concrete values of integer type",
):
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_jax_Reshape_nonconcrete():
a = vector("a")
b = iscalar("b") b = iscalar("b")
x = reshape(a, (b, b)) x = reshape(a, (b, b))
x_fg = FunctionGraph([a, b], [x]) x_fg = FunctionGraph([a, b], [x])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]) with pytest.raises(
TypeError,
match="Shapes must be 1D sequences of concrete values of integer type",
):
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
def test_jax_Dimshuffle(): def test_jax_Dimshuffle():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论