提交 03858395 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added support for while looping in Numba scan

上级 8437fa94
...@@ -69,7 +69,6 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -69,7 +69,6 @@ def numba_funcify_Scan(op, node, **kwargs):
outer_in_non_seqs_names = input_names[p_outer_in_non_seqs:] outer_in_non_seqs_names = input_names[p_outer_in_non_seqs:]
inner_in_indexed = [] inner_in_indexed = []
inner_out_indexed = []
allocate_mem_to_nit_sot = "" allocate_mem_to_nit_sot = ""
for _name in outer_in_seqs_names: for _name in outer_in_seqs_names:
...@@ -81,6 +80,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -81,6 +80,7 @@ def numba_funcify_Scan(op, node, **kwargs):
name_to_input_map = dict(zip(input_names, node.inputs[1:])) name_to_input_map = dict(zip(input_names, node.inputs[1:]))
mit_sot_name_to_taps = dict(zip(outer_in_mit_sot_names, mit_sot_in_taps)) mit_sot_name_to_taps = dict(zip(outer_in_mit_sot_names, mit_sot_in_taps))
inner_out_name_to_index = {}
for _name in outer_in_feedback_names: for _name in outer_in_feedback_names:
if _name in outer_in_mit_sot_names: if _name in outer_in_mit_sot_names:
curr_taps = mit_sot_name_to_taps[_name] curr_taps = mit_sot_name_to_taps[_name]
...@@ -90,8 +90,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -90,8 +90,7 @@ def numba_funcify_Scan(op, node, **kwargs):
index = idx_to_str(_tap - min_tap) index = idx_to_str(_tap - min_tap)
inner_in_indexed.append(_name + index) inner_in_indexed.append(_name + index)
index = idx_to_str(-min_tap) inner_out_name_to_index[_name] = -min_tap
inner_out_indexed.append(_name + index)
if _name in outer_in_sit_sot_names: if _name in outer_in_sit_sot_names:
# Note that the outputs with single taps which are not # Note that the outputs with single taps which are not
...@@ -100,12 +99,10 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -100,12 +99,10 @@ def numba_funcify_Scan(op, node, **kwargs):
# constant as follows # constant as follows
index = "[i]" index = "[i]"
inner_in_indexed.append(_name + index) inner_in_indexed.append(_name + index)
index = "[i+1]" inner_out_name_to_index[_name] = 1
inner_out_indexed.append(_name + index)
if _name in outer_in_nit_sot_names: if _name in outer_in_nit_sot_names:
index = "[i]" inner_out_name_to_index[_name] = 0
inner_out_indexed.append(_name + index)
# In case of nit-sots we are provided shape of the array # In case of nit-sots we are provided shape of the array
# instead of actual arrays like other cases, hence we # instead of actual arrays like other cases, hence we
# allocate space for the results accordingly. # allocate space for the results accordingly.
...@@ -114,19 +111,35 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -114,19 +111,35 @@ def numba_funcify_Scan(op, node, **kwargs):
""" """
# The non_seqs are passed to inner function as-is # The non_seqs are passed to inner function as-is
inner_in_indexed += outer_in_non_seqs_names inner_in_indexed += outer_in_non_seqs_names
inner_out_indexed = [
_name + idx_to_str(idx) for _name, idx in inner_out_name_to_index.items()
]
while_logic = ""
if op.as_while:
# The inner function will be returning a boolean as last argument
inner_out_indexed.append("while_flag")
while_logic += """
if while_flag:
"""
for _name, idx in inner_out_name_to_index.items():
while_logic += f"""
{_name} = {_name}[:i+{idx+1}]
"""
while_logic += """
break
"""
global_env = locals() global_env = locals()
global_env["np"] = np global_env["np"] = np
scan_op_src = f""" scan_op_src = f"""
def scan(n_steps, {", ".join(input_names)}): def scan(n_steps, {", ".join(input_names)}):
{allocate_mem_to_nit_sot} {allocate_mem_to_nit_sot}
for i in range(n_steps): for i in range(n_steps):
inner_args = {create_tuple_string(inner_in_indexed)} inner_args = {create_tuple_string(inner_in_indexed)}
{create_tuple_string(inner_out_indexed)} = numba_aet_inner_func(*inner_args) {create_tuple_string(inner_out_indexed)} = numba_aet_inner_func(*inner_args)
{while_logic}
return {create_arg_string( return {create_arg_string(
outer_in_mit_sot_names + outer_in_mit_sot_names +
outer_in_sit_sot_names + outer_in_sit_sot_names +
......
...@@ -29,6 +29,7 @@ from aesara.link.numba.dispatch import basic as numba_basic ...@@ -29,6 +29,7 @@ from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.linker import NumbaLinker from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite from aesara.scalar.basic import Composite
from aesara.scan.basic import scan from aesara.scan.basic import scan
from aesara.scan.utils import until
from aesara.tensor import blas from aesara.tensor import blas
from aesara.tensor import elemwise as aet_elemwise from aesara.tensor import elemwise as aet_elemwise
from aesara.tensor import extra_ops, nlinalg, slinalg from aesara.tensor import extra_ops, nlinalg, slinalg
...@@ -3044,3 +3045,23 @@ def test_scan_tap_output(): ...@@ -3044,3 +3045,23 @@ def test_scan_tap_output():
np.arange(20, 31, dtype=config.floatX), np.arange(20, 31, dtype=config.floatX),
] ]
compare_numba_and_py(out_fg, test_input_vals) compare_numba_and_py(out_fg, test_input_vals)
def test_scan_while():
def power_of_2(previous_power, max_value):
return previous_power * 2, until(previous_power * 2 > max_value)
max_value = aet.scalar()
values, _ = scan(
power_of_2,
outputs_info=aet.constant(1.0),
non_sequences=max_value,
n_steps=1024,
)
out_fg = FunctionGraph([max_value], [values])
test_input_vals = [
np.array(45).astype(config.floatX),
]
compare_numba_and_py(out_fg, test_input_vals)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论