Commit 8cf3b20f authored by Rémi Louf, committed by Ricardo Vieira

Refactor the JAX implementation of `Reshape`

Parent 4235ccc3
...@@ -28,11 +28,38 @@ def jax_funcify_JAXShapeTuple(op, **kwargs): ...@@ -28,11 +28,38 @@ def jax_funcify_JAXShapeTuple(op, **kwargs):
return shape_tuple_fn return shape_tuple_fn
# Error message raised (as `NotImplementedError`) by
# `assert_shape_argument_jax_compatible` when a `shape` argument cannot be
# treated as a concrete value by JAX's JIT compiler.  The text is part of the
# user-facing error, so it must not be reworded casually.
SHAPE_NOT_COMPATIBLE = """JAX requires concrete values for the `shape` parameter of `jax.numpy.reshape`.
Concrete values are either constants:
>>> import pytensor.tensor as at
>>> x = at.ones(6)
>>> y = x.reshape((2, 3))
Or the shape of an array:
>>> mat = at.matrix('mat')
>>> y = x.reshape(mat.shape)
"""
def assert_shape_argument_jax_compatible(shape):
    """Assert whether the current node can be JIT-compiled by JAX.

    JAX can JIT-compile functions with a `shape` or `size` argument if it is
    given a concrete value, i.e. either a constant or the shape of any traced
    value.

    Parameters
    ----------
    shape
        The variable passed as the ``shape`` argument of an `Op` such as
        `Reshape`.

    Raises
    ------
    NotImplementedError
        If `shape` is not the output of a `Shape`, `Shape_i` or
        `JAXShapeTuple` `Op`, and can therefore not be made concrete.
    """
    # A variable with no owner (e.g. a graph input) is not the output of any
    # `Op`; raise the informative error instead of crashing with an
    # `AttributeError` on `shape.owner.op`.
    if shape.owner is None or not isinstance(
        shape.owner.op, (Shape, Shape_i, JAXShapeTuple)
    ):
        raise NotImplementedError(SHAPE_NOT_COMPATIBLE)
@jax_funcify.register(Reshape) @jax_funcify.register(Reshape)
def jax_funcify_Reshape(op, node, **kwargs): def jax_funcify_Reshape(op, node, **kwargs):
# JAX reshape only works with constant inputs, otherwise JIT fails
shape = node.inputs[1] shape = node.inputs[1]
if isinstance(shape, Constant): if isinstance(shape, Constant):
constant_shape = shape.data constant_shape = shape.data
...@@ -40,6 +67,7 @@ def jax_funcify_Reshape(op, node, **kwargs): ...@@ -40,6 +67,7 @@ def jax_funcify_Reshape(op, node, **kwargs):
return jnp.reshape(x, constant_shape) return jnp.reshape(x, constant_shape)
else: else:
assert_shape_argument_jax_compatible(shape)
def reshape(x, shape): def reshape(x, shape):
return jnp.reshape(x, shape) return jnp.reshape(x, shape)
......
import pytensor.tensor as at
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.tensor.var import TensorVariable from pytensor.tensor.basic import MakeVector
import pytensor.tensor as at from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor
from pytensor.tensor.math import Sum from pytensor.tensor.math import Sum
from pytensor.tensor.shape import Reshape
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor
from pytensor.tensor.var import TensorVariable
@node_rewriter([AdvancedIncSubtensor]) @node_rewriter([AdvancedIncSubtensor])
...@@ -24,7 +27,7 @@ def boolean_indexing_set_or_inc(fgraph, node): ...@@ -24,7 +27,7 @@ def boolean_indexing_set_or_inc(fgraph, node):
if not isinstance(cond, TensorVariable): if not isinstance(cond, TensorVariable):
return return
if not cond.type.dtype == 'bool': if not cond.type.dtype == "bool":
return return
if op.set_instead_of_inc: if op.set_instead_of_inc:
...@@ -36,7 +39,10 @@ def boolean_indexing_set_or_inc(fgraph, node): ...@@ -36,7 +39,10 @@ def boolean_indexing_set_or_inc(fgraph, node):
# Register the boolean set/inc-subtensor rewrite for the JAX backend, at the
# same position (100) as the other JAX-specific rewrites in this module.
optdb.register(
    "jax_boolean_indexing_set_or_inc",
    in2out(boolean_indexing_set_or_inc),
    "jax",
    position=100,
)
...@@ -67,12 +73,63 @@ def boolean_indexing_sum(fgraph, node): ...@@ -67,12 +73,63 @@ def boolean_indexing_sum(fgraph, node):
if not isinstance(cond, TensorVariable): if not isinstance(cond, TensorVariable):
return return
if not cond.type.dtype == 'bool': if not cond.type.dtype == "bool":
return return
out = at.sum(at.where(cond, x, 0)) out = at.sum(at.where(cond, x, 0))
return out.owner.outputs return out.owner.outputs
# Register the boolean-indexed sum rewrite for the JAX backend, at the same
# position (100) as the other JAX-specific rewrites in this module.
optdb.register(
    "jax_boolean_indexing_sum", in2out(boolean_indexing_sum), "jax", position=100
)
@node_rewriter([Reshape])
def shape_parameter_as_tuple(fgraph, node):
    """Replace `MakeVector` and `DimShuffle` (when used to transform a scalar
    into a 1d vector) when they are found as the input of a `shape`
    parameter by `JAXShapeTuple` during transpilation.

    The JAX implementations of `MakeVector` and `DimShuffle` always return JAX
    `TracedArrays`, but JAX only accepts concrete values as inputs for the `size`
    or `shape` parameter. When these `Op`s are used to convert scalar or tuple
    inputs, however, we can avoid tracing by making them return a tuple of their
    inputs instead.

    Note that JAX does not accept scalar inputs for the `size` or `shape`
    parameters, and this rewrite also ensures that scalar inputs are turned into
    tuples during transpilation.
    """
    # Imported here to avoid a circular import with the JAX dispatch module.
    from pytensor.link.jax.dispatch.shape import JAXShapeTuple

    shape_argument = node.inputs[1]
    producer = shape_argument.owner

    # Graph inputs have no producer, and already-converted shapes need no work.
    if producer is None or isinstance(producer.op, JAXShapeTuple):
        return

    producer_op = producer.op
    built_from_elements = isinstance(producer_op, MakeVector)
    scalar_to_vector = (
        isinstance(producer_op, DimShuffle)
        and producer_op.input_broadcastable == ()
        and producer_op.new_order == ("x",)
    )

    if not (built_from_elements or scalar_to_vector):
        return

    # PyTensor converted a tuple/list (or a scalar) into a tensor here; hand
    # the raw elements to `JAXShapeTuple` so JAX receives concrete values.
    tuple_shape = JAXShapeTuple()(*producer.inputs)
    replacement_inputs = list(node.inputs)
    replacement_inputs[1] = tuple_shape
    return node.clone_with_new_inputs(replacement_inputs).outputs
# Register the rewrite under the "jax" tag at position 100, matching the other
# JAX-specific rewrites in this module, so `Reshape` nodes carry
# `JAXShapeTuple` shape arguments before transpilation.
optdb.register(
    "jax_shape_parameter_as_tuple",
    in2out(shape_parameter_as_tuple),
    "jax",
    position=100,
)
...@@ -45,30 +45,34 @@ def test_jax_specify_shape(): ...@@ -45,30 +45,34 @@ def test_jax_specify_shape():
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
def test_jax_Reshape_constant():
    """`Reshape` with a constant `shape` argument should transpile to JAX."""
    a = vector("a")
    x = reshape(a, (2, 2))
    x_fg = FunctionGraph([a], [x])
    compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
# Test breaking "omnistaging" changes in JAX.
def test_jax_Reshape_concrete_shape():
    """JAX should compile when a concrete value is passed for the `shape` parameter."""
    a = vector("a")
    x = reshape(a, a.shape)
    x_fg = FunctionGraph([a], [x])
    compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])

    # A shape computed from another shape is still concrete for JAX.
    x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
    x_fg = FunctionGraph([a], [x])
    compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
@pytest.mark.xfail(
    reason="`shape_at` should be specified as a static argument", strict=True
)
def test_jax_Reshape_shape_graph_input():
    """A `shape` built from plain graph inputs is traced, so JIT currently fails."""
    a = vector("a")
    shape_at = iscalar("b")
    x = reshape(a, (shape_at, shape_at))
    x_fg = FunctionGraph([a, shape_at], [x])
    compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
):
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
def test_jax_compile_ops(): def test_jax_compile_ops():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论