提交 ae729683 作者: Ricardo Vieira 提交者: Ricardo Vieira

Fix wrap_jax when there is a mix of statically known and unknown shapes

上级 42587563
...@@ -9,7 +9,7 @@ from pytensor.compile.function import function ...@@ -9,7 +9,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply, Op, Variable from pytensor.graph import Apply, Op, Variable
from pytensor.tensor.basic import infer_static_shape from pytensor.tensor.basic import as_tensor, infer_static_shape
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
...@@ -384,7 +384,7 @@ def _find_output_types( ...@@ -384,7 +384,7 @@ def _find_output_types(
try: try:
shape_evaluation_function = function( shape_evaluation_function = function(
[], [],
resolved_input_shapes, [as_tensor(s, dtype="int64") for s in resolved_input_shapes],
on_unused_input="ignore", on_unused_input="ignore",
mode=Mode(linker="py", optimizer="fast_compile"), mode=Mode(linker="py", optimizer="fast_compile"),
) )
...@@ -394,7 +394,7 @@ def _find_output_types( ...@@ -394,7 +394,7 @@ def _find_output_types(
"Please provide inputs with fully determined shapes by " "Please provide inputs with fully determined shapes by "
"calling pt.specify_shape." "calling pt.specify_shape."
) from e ) from e
resolved_input_shapes = shape_evaluation_function() resolved_input_shapes = [tuple(s) for s in shape_evaluation_function()]
# Determine output types using jax.eval_shape with dummy inputs # Determine output types using jax.eval_shape with dummy inputs
output_metadata_storage = {} output_metadata_storage = {}
...@@ -422,6 +422,7 @@ def _find_output_types( ...@@ -422,6 +422,7 @@ def _find_output_types(
output_static = output_metadata_storage["output_static"] output_static = output_metadata_storage["output_static"]
# If we used shape evaluation, set all output shapes to unknown # If we used shape evaluation, set all output shapes to unknown
# TODO: This is throwing away potential static shape information.
if requires_shape_evaluation: if requires_shape_evaluation:
output_types = [ output_types = [
TensorType( TensorType(
......
...@@ -559,3 +559,15 @@ class TestDtypes: ...@@ -559,3 +559,15 @@ class TestDtypes:
compare_jax_and_py([x, y], [out, *grad_out], test_values) compare_jax_and_py([x, y], [out, *grad_out], test_values)
else: else:
compare_jax_and_py([x, y], [out], test_values) compare_jax_and_py([x, y], [out], test_values)
def test_mixed_static_shape():
    """Check the static output shape of a wrapped JAX concatenate.

    Three input combinations are asserted: two inputs built with an explicit
    ``shape`` give a static output shape of ``(8,)``, while pairing one such
    input with an input built without ``shape`` (in either order) gives
    ``(None,)``.
    """
    # Built without a `shape` argument, unlike `with_static_shape` below.
    without_static_shape = shared(np.ones((3,)))
    with_static_shape = shared(np.ones((4,)), shape=(4,))

    def f(x1, x2):
        return jax.numpy.concatenate([x1, x2])

    # (inputs, expected static shape of the wrapped function's output)
    cases = [
        ((with_static_shape, with_static_shape), (8,)),
        ((with_static_shape, without_static_shape), (None,)),
        ((without_static_shape, with_static_shape), (None,)),
    ]
    for inputs, expected_shape in cases:
        # Wrap afresh for every case, as the original assertions did.
        assert wrap_jax(f)(*inputs).type.shape == expected_shape
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论