提交 ae729683 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix wrap_jax when there is a mix of statically known and unknown shapes

上级 42587563
......@@ -9,7 +9,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply, Op, Variable
from pytensor.tensor.basic import infer_static_shape
from pytensor.tensor.basic import as_tensor, infer_static_shape
from pytensor.tensor.type import TensorType
......@@ -384,7 +384,7 @@ def _find_output_types(
try:
shape_evaluation_function = function(
[],
resolved_input_shapes,
[as_tensor(s, dtype="int64") for s in resolved_input_shapes],
on_unused_input="ignore",
mode=Mode(linker="py", optimizer="fast_compile"),
)
......@@ -394,7 +394,7 @@ def _find_output_types(
"Please provide inputs with fully determined shapes by "
"calling pt.specify_shape."
) from e
resolved_input_shapes = shape_evaluation_function()
resolved_input_shapes = [tuple(s) for s in shape_evaluation_function()]
# Determine output types using jax.eval_shape with dummy inputs
output_metadata_storage = {}
......@@ -422,6 +422,7 @@ def _find_output_types(
output_static = output_metadata_storage["output_static"]
# If we used shape evaluation, set all output shapes to unknown
# TODO: This is throwing away potential static shape information.
if requires_shape_evaluation:
output_types = [
TensorType(
......
......@@ -559,3 +559,15 @@ class TestDtypes:
compare_jax_and_py([x, y], [out, *grad_out], test_values)
else:
compare_jax_and_py([x, y], [out], test_values)
def test_mixed_static_shape():
    """Regression test: wrap_jax with a mix of statically known and unknown input shapes.

    Per the commit message, wrap_jax previously failed when some inputs had
    fully determined static shapes and others did not. Static output shapes
    should only be inferred when *all* input shapes are statically known;
    otherwise the output static shape falls back to unknown (None).
    """
    # Shared variable without an explicit `shape=` kwarg: static shape unknown.
    x_unknown = shared(np.ones((3,)))
    # Shared variable with `shape=(4,)`: static shape fully known.
    x_known = shared(np.ones((4,)), shape=(4,))

    def f(x1, x2):
        # Concatenation along axis 0: result length is len(x1) + len(x2).
        return jax.numpy.concatenate([x1, x2])

    # Both inputs statically known (4,) + (4,) -> static output shape (8,).
    assert wrap_jax(f)(x_known, x_known).type.shape == (8,)
    # Any unknown input shape -> output static shape must degrade to (None,),
    # regardless of argument order.
    assert wrap_jax(f)(x_known, x_unknown).type.shape == (None,)
    assert wrap_jax(f)(x_unknown, x_known).type.shape == (None,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论