提交 799a10fd authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Fixed failing Numba Scan when n_steps is provided explicitly

上级 15fba0e3
...@@ -241,6 +241,11 @@ def create_tuple_string(x): ...@@ -241,6 +241,11 @@ def create_tuple_string(x):
return f"({args})" return f"({args})"
def create_arg_string(x):
    """Join the strings in ``x`` into a single comma-separated string.

    The result has no surrounding parentheses, so it can be spliced
    directly into generated source text (for example after a ``return``
    keyword or inside a call's argument list).

    Parameters
    ----------
    x
        An iterable of strings.

    Returns
    -------
    str
        The items of ``x`` separated by ``", "``; an empty string when
        ``x`` is empty.

    Raises
    ------
    TypeError
        If any item of ``x`` is not a string (raised by ``str.join``).
    """
    return ", ".join(x)
@singledispatch @singledispatch
def numba_typify(data, dtype=None, **kwargs): def numba_typify(data, dtype=None, **kwargs):
return data return data
......
import numba import numba
import numpy as np import numpy as np
from numba import types
from numba.extending import overload
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.link.numba.dispatch.basic import create_tuple_string, numba_funcify from aesara.link.numba.dispatch.basic import (
create_arg_string,
create_tuple_string,
numba_funcify,
)
from aesara.link.utils import compile_function_src from aesara.link.utils import compile_function_src
from aesara.scan.op import Scan from aesara.scan.op import Scan
...@@ -16,6 +22,16 @@ def idx_to_str(idx): ...@@ -16,6 +22,16 @@ def idx_to_str(idx):
return res + "]" return res + "]"
@overload(range)
def array0d_range(x):
    """Numba overload of ``range`` for a zero-dimensional array argument.

    When the argument's Numba type is an array with ``ndim == 0``, the
    returned implementation extracts the scalar with ``.item()`` and
    builds the range from it.

    Parameters
    ----------
    x
        The Numba *type* of the argument passed to ``range`` (not a
        runtime value).

    Returns
    -------
    callable or None
        The implementation to compile for 0-d arrays, or ``None`` for
        every other argument type.
    """
    if isinstance(x, types.Array) and x.ndim == 0:

        def range_arr(x):
            return range(x.item())

        return range_arr

    # No implementation is offered for any other argument type.
    return None
@numba_funcify.register(Scan) @numba_funcify.register(Scan)
def numba_funcify_Scan(op, node, **kwargs): def numba_funcify_Scan(op, node, **kwargs):
inner_fg = FunctionGraph(op.inputs, op.outputs) inner_fg = FunctionGraph(op.inputs, op.outputs)
...@@ -57,7 +73,9 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -57,7 +73,9 @@ def numba_funcify_Scan(op, node, **kwargs):
allocate_mem_to_nit_sot = "" allocate_mem_to_nit_sot = ""
for _name in outer_in_seqs_names: for _name in outer_in_seqs_names:
# TODO:Index should be updating according to sequence's taps # A sequence with multiple taps is provided as multiple modified
# input sequences to the Scan Op sliced appropriately
# to keep following the logic of a normal sequence.
index = "[i]" index = "[i]"
inner_in_indexed.append(_name + index) inner_in_indexed.append(_name + index)
...@@ -66,7 +84,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -66,7 +84,7 @@ def numba_funcify_Scan(op, node, **kwargs):
for _name in outer_in_feedback_names: for _name in outer_in_feedback_names:
if _name in outer_in_mit_sot_names: if _name in outer_in_mit_sot_names:
curr_taps = mit_sot_name_to_taps[_name] curr_taps = mit_sot_name_to_taps[_name]
min_tap = min(*curr_taps) min_tap = min(curr_taps)
for _tap in curr_taps: for _tap in curr_taps:
index = idx_to_str(_tap - min_tap) index = idx_to_str(_tap - min_tap)
...@@ -76,18 +94,23 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -76,18 +94,23 @@ def numba_funcify_Scan(op, node, **kwargs):
inner_out_indexed.append(_name + index) inner_out_indexed.append(_name + index)
if _name in outer_in_sit_sot_names: if _name in outer_in_sit_sot_names:
# TODO: Input according to taps # Note that the outputs with single taps which are not
# -1 (for instance taps = [-2]) are classified
# as mit-sot so the code for handling sit-sots remains
# constant as follows
index = "[i]" index = "[i]"
inner_in_indexed.append(_name + index) inner_in_indexed.append(_name + index)
index = "[i+1]" index = "[i+1]"
inner_out_indexed.append(_name + index) inner_out_indexed.append(_name + index)
if _name in outer_in_nit_sot_names: if _name in outer_in_nit_sot_names:
# TODO: Allocate this properly
index = "[i]" index = "[i]"
inner_out_indexed.append(_name + index) inner_out_indexed.append(_name + index)
# In case of nit-sots we are provided shape of the array
# instead of actual arrays like other cases, hence we
# allocate space for the results accordingly.
allocate_mem_to_nit_sot += f""" allocate_mem_to_nit_sot += f"""
{_name} = np.zeros(n_steps) {_name} = np.zeros({_name}.item())
""" """
# The non_seqs are passed to inner function as-is # The non_seqs are passed to inner function as-is
inner_in_indexed += outer_in_non_seqs_names inner_in_indexed += outer_in_non_seqs_names
...@@ -97,23 +120,18 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -97,23 +120,18 @@ def numba_funcify_Scan(op, node, **kwargs):
scan_op_src = f""" scan_op_src = f"""
def scan(n_steps, {", ".join(input_names)}): def scan(n_steps, {", ".join(input_names)}):
outer_in_seqs = {create_tuple_string(outer_in_seqs_names)}
outer_in_mit_sot = {create_tuple_string(outer_in_mit_sot_names)}
outer_in_sit_sot = {create_tuple_string(outer_in_sit_sot_names)}
outer_in_shared = {create_tuple_string(outer_in_shared_names)}
outer_in_non_seqs = {create_tuple_string(outer_in_non_seqs_names)}
{allocate_mem_to_nit_sot} {allocate_mem_to_nit_sot}
outer_in_nit_sot = {create_tuple_string(outer_in_nit_sot_names)}
for i in range(n_steps): for i in range(n_steps):
inner_args = {create_tuple_string(inner_in_indexed)} inner_args = {create_tuple_string(inner_in_indexed)}
{create_tuple_string(inner_out_indexed)} = numba_aet_inner_func(*inner_args) {create_tuple_string(inner_out_indexed)} = numba_aet_inner_func(*inner_args)
return ( return {create_arg_string(
outer_in_mit_sot + outer_in_mit_sot_names +
outer_in_sit_sot + outer_in_sit_sot_names +
outer_in_nit_sot outer_in_nit_sot_names
) )}
""" """
scalar_op_fn = compile_function_src(scan_op_src, "scan", global_env) scalar_op_fn = compile_function_src(scan_op_src, "scan", global_env)
......
...@@ -2988,17 +2988,17 @@ def test_scan_tap_output(): ...@@ -2988,17 +2988,17 @@ def test_scan_tap_output():
a_aet = aet.scalar("a") a_aet = aet.scalar("a")
a_aet.tag.test_value = 10.0 a_aet.tag.test_value = 10.0
b_aet = aet.arange(10).astype(config.floatX) b_aet = aet.arange(11).astype(config.floatX)
b_aet.name = "b" b_aet.name = "b"
c_aet = aet.arange(20, 30, dtype=config.floatX) c_aet = aet.arange(20, 31, dtype=config.floatX)
c_aet.name = "c" c_aet.name = "c"
def input_step_fn(b, c, x_tm1, y_tm1, y_tm3, a): def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a):
x_tm1.name = "x_tm1" x_tm1.name = "x_tm1"
y_tm1.name = "y_tm1" y_tm1.name = "y_tm1"
y_tm3.name = "y_tm3" y_tm3.name = "y_tm3"
y_t = (y_tm1 + y_tm3) * a + b y_t = (y_tm1 + y_tm3) * a + b + b2
z_t = y_t * c z_t = y_t * c
x_t = x_tm1 + 1 x_t = x_tm1 + 1
x_t.name = "x_t" x_t.name = "x_t"
...@@ -3007,7 +3007,16 @@ def test_scan_tap_output(): ...@@ -3007,7 +3007,16 @@ def test_scan_tap_output():
scan_res, _ = scan( scan_res, _ = scan(
fn=input_step_fn, fn=input_step_fn,
sequences=[b_aet, c_aet], sequences=[
{
"input": b_aet,
"taps": [-1, -2],
},
{
"input": c_aet,
"taps": [-2],
},
],
outputs_info=[ outputs_info=[
{ {
"initial": aet.as_tensor_variable(0.0, dtype=config.floatX), "initial": aet.as_tensor_variable(0.0, dtype=config.floatX),
...@@ -3022,7 +3031,7 @@ def test_scan_tap_output(): ...@@ -3022,7 +3031,7 @@ def test_scan_tap_output():
None, None,
], ],
non_sequences=[a_aet], non_sequences=[a_aet],
# n_steps=10, n_steps=10,
name="yz_scan", name="yz_scan",
strict=True, strict=True,
) )
...@@ -3031,7 +3040,7 @@ def test_scan_tap_output(): ...@@ -3031,7 +3040,7 @@ def test_scan_tap_output():
test_input_vals = [ test_input_vals = [
np.array(10.0).astype(config.floatX), np.array(10.0).astype(config.floatX),
np.arange(10, dtype=config.floatX), np.arange(11, dtype=config.floatX),
np.arange(20, 30, dtype=config.floatX), np.arange(20, 31, dtype=config.floatX),
] ]
compare_numba_and_py(out_fg, test_input_vals) compare_numba_and_py(out_fg, test_input_vals)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论