提交 14e6c781 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Reimplement JAX Scan dispatcher with MIT-MOT support

上级 97797975
...@@ -307,6 +307,17 @@ class ScanMethodsMixin: ...@@ -307,6 +307,17 @@ class ScanMethodsMixin:
n_taps = sum(len(x) for x in self.info.mit_mot_out_slices) n_taps = sum(len(x) for x in self.info.mit_mot_out_slices)
return list_outputs[:n_taps] return list_outputs[:n_taps]
def inner_mitmot_outs_grouped(self, list_outputs):
# Like inner_mitmot_outs but returns a list of lists, one per mitmot
# Instead of a flat list
n_taps = [len(x) for x in self.info.mit_mot_out_slices]
grouped_outs = []
offset = 0
for nt in n_taps:
grouped_outs.append(list_outputs[offset : offset + nt])
offset += nt
return grouped_outs
def outer_mitmot_outs(self, list_outputs): def outer_mitmot_outs(self, list_outputs):
return list_outputs[: self.info.n_mit_mot] return list_outputs[: self.info.n_mit_mot]
......
...@@ -7,6 +7,8 @@ import pytensor.tensor as pt ...@@ -7,6 +7,8 @@ import pytensor.tensor as pt
from pytensor import function, ifelse, shared from pytensor import function, ifelse, shared
from pytensor.compile import get_mode from pytensor.compile import get_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scan import until from pytensor.scan import until
from pytensor.scan.basic import scan from pytensor.scan.basic import scan
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
...@@ -98,16 +100,26 @@ def test_scan_nit_sot(view): ...@@ -98,16 +100,26 @@ def test_scan_nit_sot(view):
assert len(scan_nodes) == 1 assert len(scan_nodes) == 1
@pytest.mark.xfail(raises=NotImplementedError)
def test_scan_mit_mot(): def test_scan_mit_mot():
xs = pt.vector("xs", shape=(10,)) def step(xtm1, ytm3, ytm1, rho):
ys, _ = scan( return (xtm1 + ytm1) * rho, ytm3 * (1 - rho) + ytm1 * rho
lambda xtm2, xtm1: (xtm2 + xtm1),
outputs_info=[{"initial": xs, "taps": [-2, -1]}], rho = pt.scalar("rho", dtype="float64")
x0 = pt.vector("xs", shape=(2,))
y0 = pt.vector("ys", shape=(3,))
[outs, _], _ = scan(
step,
outputs_info=[x0, {"initial": y0, "taps": [-3, -1]}],
non_sequences=[rho],
n_steps=10, n_steps=10,
) )
grads_wrt_xs = pt.grad(ys.sum(), wrt=xs) grads = pt.grad(outs.sum(), wrt=[x0, y0, rho])
compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(10)]) compare_jax_and_py(
[x0, y0, rho],
grads,
[np.arange(2), np.array([0.5, 0.5, 0.5]), np.array(0.95)],
jax_mode=get_mode("JAX"),
)
def test_scan_update(): def test_scan_update():
...@@ -323,13 +335,41 @@ def test_default_mode_excludes_incompatible_rewrites(): ...@@ -323,13 +335,41 @@ def test_default_mode_excludes_incompatible_rewrites():
def test_dynamic_sequence_length(): def test_dynamic_sequence_length():
x = pt.tensor("x", shape=(None,)) class IncWithoutStaticShape(Op):
out, _ = scan(lambda x: x + 1, sequences=[x]) def make_node(self, x):
x = pt.as_tensor_variable(x)
return Apply(self, [x], [pt.tensor(shape=(None,) * x.type.ndim)])
def perform(self, node, inputs, outputs):
outputs[0][0] = inputs[0] + 1
@jax_funcify.register(IncWithoutStaticShape)
def _(op, **kwargs):
return lambda x: x + 1
inc_without_static_shape = IncWithoutStaticShape()
x = pt.tensor("x", shape=(None, 3))
out, _ = scan(
lambda x: inc_without_static_shape(x), outputs_info=[None], sequences=[x]
)
f = function([x], out, mode=get_mode("JAX").excluding("scan")) f = function([x], out, mode=get_mode("JAX").excluding("scan"))
assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1 assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
np.testing.assert_allclose(f([]), []) np.testing.assert_allclose(f([[1, 2, 3]]), np.array([[2, 3, 4]]))
np.testing.assert_allclose(f([1, 2, 3]), np.array([2, 3, 4]))
with pytest.raises(ValueError):
f(np.zeros((0, 3)))
# But should be fine with static shape
out2, _ = scan(
lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape),
outputs_info=[None],
sequences=[x],
)
f2 = function([x], out2, mode=get_mode("JAX").excluding("scan"))
np.testing.assert_allclose(f2([[1, 2, 3]]), np.array([[2, 3, 4]]))
np.testing.assert_allclose(f2(np.zeros((0, 3))), np.empty((0, 3)))
def SEIR_model_logp(): def SEIR_model_logp():
...@@ -499,9 +539,6 @@ def cyclical_reduction(): ...@@ -499,9 +539,6 @@ def cyclical_reduction():
@pytest.mark.parametrize("mode", ("0forward", "1backward", "2both")) @pytest.mark.parametrize("mode", ("0forward", "1backward", "2both"))
@pytest.mark.parametrize("model", [cyclical_reduction, SEIR_model_logp]) @pytest.mark.parametrize("model", [cyclical_reduction, SEIR_model_logp])
def test_scan_benchmark(model, mode, gradient_backend, benchmark): def test_scan_benchmark(model, mode, gradient_backend, benchmark):
if gradient_backend == "PYTENSOR" and mode in ("1backward", "2both"):
pytest.skip("PYTENSOR backend does not support backward mode yet")
model_dict = model() model_dict = model()
graph_inputs = model_dict["graph_inputs"] graph_inputs = model_dict["graph_inputs"]
differentiable_vars = model_dict["differentiable_vars"] differentiable_vars = model_dict["differentiable_vars"]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论