提交 799a10fd authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Fixed failing Numba Scan when n_steps are provided explicitly

上级 15fba0e3
......@@ -241,6 +241,11 @@ def create_tuple_string(x):
return f"({args})"
def create_arg_string(x):
args = ", ".join(x)
return args
@singledispatch
def numba_typify(data, dtype=None, **kwargs):
return data
......
import numba
import numpy as np
from numba import types
from numba.extending import overload
from aesara.graph.fg import FunctionGraph
from aesara.link.numba.dispatch.basic import create_tuple_string, numba_funcify
from aesara.link.numba.dispatch.basic import (
create_arg_string,
create_tuple_string,
numba_funcify,
)
from aesara.link.utils import compile_function_src
from aesara.scan.op import Scan
......@@ -16,6 +22,16 @@ def idx_to_str(idx):
return res + "]"
@overload(range)
def array0d_range(x):
if isinstance(x, types.Array) and x.ndim == 0:
def range_arr(x):
return range(x.item())
return range_arr
@numba_funcify.register(Scan)
def numba_funcify_Scan(op, node, **kwargs):
inner_fg = FunctionGraph(op.inputs, op.outputs)
......@@ -57,7 +73,9 @@ def numba_funcify_Scan(op, node, **kwargs):
allocate_mem_to_nit_sot = ""
for _name in outer_in_seqs_names:
# TODO:Index sould be updating according to sequence's taps
# A sequence with multiple taps is provided as multiple modified
# input sequences to the Scan Op sliced appropriately
# to keep following the logic of a normal sequence.
index = "[i]"
inner_in_indexed.append(_name + index)
......@@ -66,7 +84,7 @@ def numba_funcify_Scan(op, node, **kwargs):
for _name in outer_in_feedback_names:
if _name in outer_in_mit_sot_names:
curr_taps = mit_sot_name_to_taps[_name]
min_tap = min(*curr_taps)
min_tap = min(curr_taps)
for _tap in curr_taps:
index = idx_to_str(_tap - min_tap)
......@@ -76,18 +94,23 @@ def numba_funcify_Scan(op, node, **kwargs):
inner_out_indexed.append(_name + index)
if _name in outer_in_sit_sot_names:
# TODO: Input according to taps
# Note that the outputs with single taps which are not
# -1 are (for instance taps = [-2]) are classified
# as mit-sot so the code for handling sit-sots remains
# constant as follows
index = "[i]"
inner_in_indexed.append(_name + index)
index = "[i+1]"
inner_out_indexed.append(_name + index)
if _name in outer_in_nit_sot_names:
# TODO: Allocate this properly
index = "[i]"
inner_out_indexed.append(_name + index)
# In case of nit-sots we are provided shape of the array
# instead of actual arrays like other cases, hence we
# allocate space for the results accordingly.
allocate_mem_to_nit_sot += f"""
{_name} = np.zeros(n_steps)
{_name} = np.zeros({_name}.item())
"""
# The non_seqs are passed to inner function as-is
inner_in_indexed += outer_in_non_seqs_names
......@@ -97,23 +120,18 @@ def numba_funcify_Scan(op, node, **kwargs):
scan_op_src = f"""
def scan(n_steps, {", ".join(input_names)}):
outer_in_seqs = {create_tuple_string(outer_in_seqs_names)}
outer_in_mit_sot = {create_tuple_string(outer_in_mit_sot_names)}
outer_in_sit_sot = {create_tuple_string(outer_in_sit_sot_names)}
outer_in_shared = {create_tuple_string(outer_in_shared_names)}
outer_in_non_seqs = {create_tuple_string(outer_in_non_seqs_names)}
{allocate_mem_to_nit_sot}
outer_in_nit_sot = {create_tuple_string(outer_in_nit_sot_names)}
for i in range(n_steps):
inner_args = {create_tuple_string(inner_in_indexed)}
{create_tuple_string(inner_out_indexed)} = numba_aet_inner_func(*inner_args)
return (
outer_in_mit_sot +
outer_in_sit_sot +
outer_in_nit_sot
)
return {create_arg_string(
outer_in_mit_sot_names +
outer_in_sit_sot_names +
outer_in_nit_sot_names
)}
"""
scalar_op_fn = compile_function_src(scan_op_src, "scan", global_env)
......
......@@ -2988,17 +2988,17 @@ def test_scan_tap_output():
a_aet = aet.scalar("a")
a_aet.tag.test_value = 10.0
b_aet = aet.arange(10).astype(config.floatX)
b_aet = aet.arange(11).astype(config.floatX)
b_aet.name = "b"
c_aet = aet.arange(20, 30, dtype=config.floatX)
c_aet = aet.arange(20, 31, dtype=config.floatX)
c_aet.name = "c"
def input_step_fn(b, c, x_tm1, y_tm1, y_tm3, a):
def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a):
x_tm1.name = "x_tm1"
y_tm1.name = "y_tm1"
y_tm3.name = "y_tm3"
y_t = (y_tm1 + y_tm3) * a + b
y_t = (y_tm1 + y_tm3) * a + b + b2
z_t = y_t * c
x_t = x_tm1 + 1
x_t.name = "x_t"
......@@ -3007,7 +3007,16 @@ def test_scan_tap_output():
scan_res, _ = scan(
fn=input_step_fn,
sequences=[b_aet, c_aet],
sequences=[
{
"input": b_aet,
"taps": [-1, -2],
},
{
"input": c_aet,
"taps": [-2],
},
],
outputs_info=[
{
"initial": aet.as_tensor_variable(0.0, dtype=config.floatX),
......@@ -3022,7 +3031,7 @@ def test_scan_tap_output():
None,
],
non_sequences=[a_aet],
# n_steps=10,
n_steps=10,
name="yz_scan",
strict=True,
)
......@@ -3031,7 +3040,7 @@ def test_scan_tap_output():
test_input_vals = [
np.array(10.0).astype(config.floatX),
np.arange(10, dtype=config.floatX),
np.arange(20, 30, dtype=config.floatX),
np.arange(11, dtype=config.floatX),
np.arange(20, 31, dtype=config.floatX),
]
compare_numba_and_py(out_fg, test_input_vals)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论