Commit f203dd78 authored by Ricardo Vieira, committed by Ricardo Vieira

Optimize Scan inner graph when compiling to Numba

Parent 0451e008
@@ -47,6 +47,13 @@ def array0d_range(x):
@numba_funcify.register(Scan)
def numba_funcify_Scan(op, node, **kwargs):
# Apply inner rewrites
# TODO: Not sure this is the right place to do this. Should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of Scan?
# The C backend defers this to the make_thunk phase.
rewriter = op.mode_instance.optimizer
rewriter(op.fgraph)
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
outer_in_names_to_vars = {
...
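For context, here is a minimal standalone sketch (not part of the diff) of what calling the mode's rewriter on a FunctionGraph does. It assumes the API the new code relies on (Mode.optimizer returning a callable graph rewriter) and reuses the same log(1 + x) graph the new test below exercises:

import pytensor.tensor as at
from pytensor.compile.mode import get_mode
from pytensor.graph.fg import FunctionGraph
from pytensor.printing import debugprint

x = at.vector("x")
# Built naively, log(1 + x) is two separate Elemwise nodes (Add, then Log)
fg = FunctionGraph([x], [at.log(1 + x)])

rewriter = get_mode("NUMBA").optimizer
rewriter(fg)  # rewrites fg in place; e.g. log(1 + x) should fuse into log1p(x)
debugprint(fg)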
@@ -5,9 +5,12 @@ import pytensor.tensor as at
from pytensor import config, function, grad
from pytensor.compile.mode import Mode, get_mode
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar import Log1p
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
from pytensor.scan.utils import until
from pytensor.tensor import log, vector
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import RandomStream
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py
@@ -417,3 +420,25 @@ def test_mitmots_basic():
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)


def test_inner_graph_optimized():
    """Test that the inner graph of a Scan is optimized."""
xs = vector("xs")
seq, _ = scan(
fn=lambda x: log(1 + x),
sequences=[xs],
mode=get_mode("NUMBA"),
)
    # Disable scan pushout; otherwise the whole Scan would be replaced by an Elemwise
f = function([xs], seq, mode=get_mode("NUMBA").excluding("scan_pushout"))
(scan_node,) = [
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
    inner_scan_nodes = scan_node.op.fgraph.apply_nodes
    assert len(inner_scan_nodes) == 1
    (inner_scan_node,) = inner_scan_nodes
assert isinstance(inner_scan_node.op, Elemwise) and isinstance(
inner_scan_node.op.scalar_op, Log1p
)
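
A hypothetical numerical check (not in the commit) that could be appended to the test body, reusing the test's f and the file's config import:

import numpy as np

xs_val = np.random.default_rng(0).uniform(size=5).astype(config.floatX)
# The optimized inner graph must still compute log(1 + x) elementwise
np.testing.assert_allclose(f(xs_val), np.log1p(xs_val), rtol=1e-6)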