提交 03858395 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added support for while looping in Numba scan

上级 8437fa94
......@@ -69,7 +69,6 @@ def numba_funcify_Scan(op, node, **kwargs):
outer_in_non_seqs_names = input_names[p_outer_in_non_seqs:]
inner_in_indexed = []
inner_out_indexed = []
allocate_mem_to_nit_sot = ""
for _name in outer_in_seqs_names:
......@@ -81,6 +80,7 @@ def numba_funcify_Scan(op, node, **kwargs):
name_to_input_map = dict(zip(input_names, node.inputs[1:]))
mit_sot_name_to_taps = dict(zip(outer_in_mit_sot_names, mit_sot_in_taps))
inner_out_name_to_index = {}
for _name in outer_in_feedback_names:
if _name in outer_in_mit_sot_names:
curr_taps = mit_sot_name_to_taps[_name]
......@@ -90,8 +90,7 @@ def numba_funcify_Scan(op, node, **kwargs):
index = idx_to_str(_tap - min_tap)
inner_in_indexed.append(_name + index)
index = idx_to_str(-min_tap)
inner_out_indexed.append(_name + index)
inner_out_name_to_index[_name] = -min_tap
if _name in outer_in_sit_sot_names:
# Note that the outputs with single taps which are not
......@@ -100,12 +99,10 @@ def numba_funcify_Scan(op, node, **kwargs):
# constant as follows
index = "[i]"
inner_in_indexed.append(_name + index)
index = "[i+1]"
inner_out_indexed.append(_name + index)
inner_out_name_to_index[_name] = 1
if _name in outer_in_nit_sot_names:
index = "[i]"
inner_out_indexed.append(_name + index)
inner_out_name_to_index[_name] = 0
# In case of nit-sots we are provided shape of the array
# instead of actual arrays like other cases, hence we
# allocate space for the results accordingly.
......@@ -114,19 +111,35 @@ def numba_funcify_Scan(op, node, **kwargs):
"""
# The non_seqs are passed to inner function as-is
inner_in_indexed += outer_in_non_seqs_names
inner_out_indexed = [
_name + idx_to_str(idx) for _name, idx in inner_out_name_to_index.items()
]
while_logic = ""
if op.as_while:
# The inner function returns a boolean stop flag as its last output
inner_out_indexed.append("while_flag")
while_logic += """
if while_flag:
"""
for _name, idx in inner_out_name_to_index.items():
while_logic += f"""
{_name} = {_name}[:i+{idx+1}]
"""
while_logic += """
break
"""
global_env = locals()
global_env["np"] = np
scan_op_src = f"""
def scan(n_steps, {", ".join(input_names)}):
{allocate_mem_to_nit_sot}
for i in range(n_steps):
inner_args = {create_tuple_string(inner_in_indexed)}
{create_tuple_string(inner_out_indexed)} = numba_aet_inner_func(*inner_args)
{while_logic}
return {create_arg_string(
outer_in_mit_sot_names +
outer_in_sit_sot_names +
......
......@@ -29,6 +29,7 @@ from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite
from aesara.scan.basic import scan
from aesara.scan.utils import until
from aesara.tensor import blas
from aesara.tensor import elemwise as aet_elemwise
from aesara.tensor import extra_ops, nlinalg, slinalg
......@@ -3044,3 +3045,23 @@ def test_scan_tap_output():
np.arange(20, 31, dtype=config.floatX),
]
compare_numba_and_py(out_fg, test_input_vals)
def test_scan_while():
    """Run a `scan` whose step function returns an `until` condition through
    `compare_numba_and_py`.

    The step function doubles the carried value and pairs it with
    ``until(previous_power * 2 > max_value)``; the scan is given an upper
    bound of 1024 steps and is evaluated with ``max_value == 45``.
    """

    def double_with_stop(previous_power, max_value):
        # The doubled value is deliberately written out twice so that the
        # symbolic graph is built the same way as a single-expression return.
        next_power = previous_power * 2
        stop_condition = until(previous_power * 2 > max_value)
        return next_power, stop_condition

    limit = aet.scalar()

    # Only the scan outputs are needed; the second returned item is ignored.
    powers, _ = scan(
        double_with_stop,
        outputs_info=aet.constant(1.0),
        non_sequences=limit,
        n_steps=1024,
    )

    fgraph = FunctionGraph([limit], [powers])
    input_vals = [np.array(45).astype(config.floatX)]

    compare_numba_and_py(fgraph, input_vals)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论