提交 b3b68618 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Only MIT-MOT require working on buffers directly

Using JAX Scan machinery to create MIT-SOT, SIT-SOT, and NIT-SOT buffers for us seems to be more performant than working directly on the pre-allocated buffers and reading/writing at every iteration. There is no machinery to work with MIT-MOT directly (just like in PyTensor user-facing Scan).
上级 14e6c781
......@@ -8,7 +8,6 @@ from pytensor import function, ifelse, shared
from pytensor.compile import get_mode
from pytensor.configdefaults import config
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scan import until
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
......@@ -335,6 +334,9 @@ def test_default_mode_excludes_incompatible_rewrites():
def test_dynamic_sequence_length():
# Imported here to not trigger import of JAX in non-JAX CI jobs
from pytensor.link.jax.dispatch.basic import jax_funcify
class IncWithoutStaticShape(Op):
def make_node(self, x):
x = pt.as_tensor_variable(x)
......@@ -358,10 +360,10 @@ def test_dynamic_sequence_length():
assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
np.testing.assert_allclose(f([[1, 2, 3]]), np.array([[2, 3, 4]]))
with pytest.raises(ValueError):
f(np.zeros((0, 3)))
# This works if we use JAX scan internally, but not if we use a fori_loop with a buffer allocated by us
np.testing.assert_allclose(f(np.zeros((0, 3))), np.empty((0, 3)))
# But should be fine with static shape
# With known static shape we should always manage, regardless of the internal implementation
out2, _ = scan(
lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape),
outputs_info=[None],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论