提交 a8303a0d 作者: Ricardo Vieira 提交者: Ricardo Vieira

Cleanup elemwise_cgen.py

上级 49daa85b
from textwrap import dedent, indent
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -8,12 +10,10 @@ def make_declare(loop_orders, dtypes, sub): ...@@ -8,12 +10,10 @@ def make_declare(loop_orders, dtypes, sub):
""" """
decl = "" decl = ""
for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)): for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)):
var = sub[f"lv{int(i)}"] # input name corresponding to ith loop variable var = sub[f"lv{i}"] # input name corresponding to ith loop variable
# we declare an iteration variable # we declare an iteration variable
# and an integer for the number of dimensions # and an integer for the number of dimensions
decl += f""" decl += f"{dtype}* {var}_iter;\n"
{dtype}* {var}_iter;
"""
for j, value in enumerate(loop_order): for j, value in enumerate(loop_order):
if value != "x": if value != "x":
# If the dimension is not broadcasted, we declare # If the dimension is not broadcasted, we declare
...@@ -21,17 +21,15 @@ def make_declare(loop_orders, dtypes, sub): ...@@ -21,17 +21,15 @@ def make_declare(loop_orders, dtypes, sub):
# the stride in that dimension, # the stride in that dimension,
# and the jump from an iteration to the next # and the jump from an iteration to the next
decl += f""" decl += f"""
npy_intp {var}_n{int(value)}; npy_intp {var}_n{value};
ssize_t {var}_stride{int(value)}; ssize_t {var}_stride{value};
int {var}_jump{int(value)}_{int(j)}; int {var}_jump{value}_{j};
""" """
else: else:
# if the dimension is broadcasted, we only need # if the dimension is broadcasted, we only need
# the jump (arbitrary length and stride = 0) # the jump (arbitrary length and stride = 0)
decl += f""" decl += f"int {var}_jump{value}_{j};\n"
int {var}_jump{value}_{int(j)};
"""
return decl return decl
...@@ -39,7 +37,7 @@ def make_declare(loop_orders, dtypes, sub): ...@@ -39,7 +37,7 @@ def make_declare(loop_orders, dtypes, sub):
def make_checks(loop_orders, dtypes, sub): def make_checks(loop_orders, dtypes, sub):
init = "" init = ""
for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)): for i, (loop_order, dtype) in enumerate(zip(loop_orders, dtypes)):
var = f"%(lv{int(i)})s" var = sub[f"lv{i}"]
# List of dimensions of var that are not broadcasted # List of dimensions of var that are not broadcasted
nonx = [x for x in loop_order if x != "x"] nonx = [x for x in loop_order if x != "x"]
if nonx: if nonx:
...@@ -47,12 +45,14 @@ def make_checks(loop_orders, dtypes, sub): ...@@ -47,12 +45,14 @@ def make_checks(loop_orders, dtypes, sub):
# this is a check that the number of dimensions of the # this is a check that the number of dimensions of the
# tensor is as expected. # tensor is as expected.
min_nd = max(nonx) + 1 min_nd = max(nonx) + 1
init += f""" init += dedent(
f"""
if (PyArray_NDIM({var}) < {min_nd}) {{ if (PyArray_NDIM({var}) < {min_nd}) {{
PyErr_SetString(PyExc_ValueError, "Not enough dimensions on input."); PyErr_SetString(PyExc_ValueError, "Not enough dimensions on input.");
%(fail)s {indent(sub["fail"], " " * 12)}
}} }}
""" """
)
# In loop j, adjust represents the difference of values of the # In loop j, adjust represents the difference of values of the
# data pointer between the beginning and the end of the # data pointer between the beginning and the end of the
...@@ -75,9 +75,7 @@ def make_checks(loop_orders, dtypes, sub): ...@@ -75,9 +75,7 @@ def make_checks(loop_orders, dtypes, sub):
adjust = f"{var}_n{index}*{var}_stride{index}" adjust = f"{var}_n{index}*{var}_stride{index}"
else: else:
jump = f"-({adjust})" jump = f"-({adjust})"
init += f""" init += f"{var}_jump{index}_{j} = {jump};\n"
{var}_jump{index}_{j} = {jump};
"""
adjust = "0" adjust = "0"
check = "" check = ""
...@@ -101,34 +99,36 @@ def make_checks(loop_orders, dtypes, sub): ...@@ -101,34 +99,36 @@ def make_checks(loop_orders, dtypes, sub):
j0, x0 = to_compare[0] j0, x0 = to_compare[0]
for j, x in to_compare[1:]: for j, x in to_compare[1:]:
check += f""" check += dedent(
if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x}) f"""
if ({sub[f"lv{j0}"]}_n{x0} != {sub[f"lv{j}"]}_n{x})
{{ {{
if (%(lv{j0})s_n{x0} == 1 || %(lv{j})s_n{x} == 1) if ({sub[f"lv{j0}"]}_n{x0} == 1 || {sub[f"lv{j}"]}_n{x} == 1)
{{ {{
PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}", PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}",
{j0}, {j0},
{x0}, {x0},
(long long int) %(lv{j0})s_n{x0}, (long long int) {sub[f"lv{j0}"]}_n{x0},
{j}, {j},
{x}, {x},
(long long int) %(lv{j})s_n{x} (long long int) {sub[f"lv{j}"]}_n{x}
); );
}} else {{ }} else {{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)", PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
{j0}, {j0},
{x0}, {x0},
(long long int) %(lv{j0})s_n{x0}, (long long int) {sub[f"lv{j0}"]}_n{x0},
{j}, {j},
{x}, {x},
(long long int) %(lv{j})s_n{x} (long long int) {sub[f"lv{j}"]}_n{x}
); );
}} }}
%(fail)s {sub["fail"]}
}} }}
""" """
)
return init % sub + check % sub return init + check
def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str: def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str:
...@@ -144,7 +144,7 @@ def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str: ...@@ -144,7 +144,7 @@ def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str:
# Borrow the length of the first non-broadcastable input dimension # Borrow the length of the first non-broadcastable input dimension
for j, candidate in enumerate(candidates): for j, candidate in enumerate(candidates):
if candidate != "x": if candidate != "x":
var = sub[f"lv{int(j)}"] var = sub[f"lv{j}"]
dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n" dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
break break
# If none is non-broadcastable, the output dimension has a length of 1 # If none is non-broadcastable, the output dimension has a length of 1
...@@ -177,13 +177,14 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): ...@@ -177,13 +177,14 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
# way that its contiguous dimensions match one of the input's # way that its contiguous dimensions match one of the input's
# contiguous dimensions, or the dimension with the smallest # contiguous dimensions, or the dimension with the smallest
# stride. Right now, it is allocated to be C_CONTIGUOUS. # stride. Right now, it is allocated to be C_CONTIGUOUS.
return f""" return dedent(
f"""
{{ {{
npy_intp dims[{nd}]; npy_intp dims[{nd}];
//npy_intp* dims = (npy_intp*)malloc({nd} * sizeof(npy_intp));
{init_dims} {init_dims}
if (!{olv}) {{ if (!{olv}) {{
{olv} = (PyArrayObject*)PyArray_EMPTY({nd}, dims, {olv} = (PyArrayObject*)PyArray_EMPTY({nd},
dims,
{type}, {type},
{fortran}); {fortran});
}} }}
...@@ -206,6 +207,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): ...@@ -206,6 +207,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
}} }}
}} }}
""" """
)
def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None): def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
...@@ -235,11 +237,11 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None): ...@@ -235,11 +237,11 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
""" """
def loop_over(preloop, code, indices, i): def loop_over(preloop, code, indices, i):
iterv = f"ITER_{int(i)}" iterv = f"ITER_{i}"
update = "" update = ""
suitable_n = "1" suitable_n = "1"
for j, index in enumerate(indices): for j, index in enumerate(indices):
var = sub[f"lv{int(j)}"] var = sub[f"lv{j}"]
dtype = dtypes[j] dtype = dtypes[j]
update += f"{dtype} &{var}_i = * ( {var}_iter + {iterv} * {var}_jump{index}_{i} );\n" update += f"{dtype} &{var}_i = * ( {var}_iter + {iterv} * {var}_jump{index}_{i} );\n"
...@@ -305,13 +307,13 @@ def make_reordered_loop( ...@@ -305,13 +307,13 @@ def make_reordered_loop(
nnested = len(init_loop_orders[0]) nnested = len(init_loop_orders[0])
# This is the var from which we'll get the loop order # This is the var from which we'll get the loop order
ovar = sub[f"lv{int(olv_index)}"] ovar = sub[f"lv{olv_index}"]
# The loops are ordered by (decreasing) absolute values of ovar's strides. # The loops are ordered by (decreasing) absolute values of ovar's strides.
# The first element of each pair is the absolute value of the stride # The first element of each pair is the absolute value of the stride
# The second element correspond to the index in the initial loop order # The second element correspond to the index in the initial loop order
order_loops = f""" order_loops = f"""
std::vector< std::pair<int, int> > {ovar}_loops({int(nnested)}); std::vector< std::pair<int, int> > {ovar}_loops({nnested});
std::vector< std::pair<int, int> >::iterator {ovar}_loops_it = {ovar}_loops.begin(); std::vector< std::pair<int, int> >::iterator {ovar}_loops_it = {ovar}_loops.begin();
""" """
...@@ -319,7 +321,7 @@ def make_reordered_loop( ...@@ -319,7 +321,7 @@ def make_reordered_loop(
for i, index in enumerate(init_loop_orders[olv_index]): for i, index in enumerate(init_loop_orders[olv_index]):
if index != "x": if index != "x":
order_loops += f""" order_loops += f"""
{ovar}_loops_it->first = abs(PyArray_STRIDES({ovar})[{int(index)}]); {ovar}_loops_it->first = abs(PyArray_STRIDES({ovar})[{index}]);
""" """
else: else:
# Stride is 0 when dimension is broadcastable # Stride is 0 when dimension is broadcastable
...@@ -328,7 +330,7 @@ def make_reordered_loop( ...@@ -328,7 +330,7 @@ def make_reordered_loop(
""" """
order_loops += f""" order_loops += f"""
{ovar}_loops_it->second = {int(i)}; {ovar}_loops_it->second = {i};
++{ovar}_loops_it; ++{ovar}_loops_it;
""" """
...@@ -352,7 +354,7 @@ def make_reordered_loop( ...@@ -352,7 +354,7 @@ def make_reordered_loop(
for i in range(nnested): for i in range(nnested):
declare_totals += f""" declare_totals += f"""
int TOTAL_{int(i)} = init_totals[{ovar}_loops_it->second]; int TOTAL_{i} = init_totals[{ovar}_loops_it->second];
++{ovar}_loops_it; ++{ovar}_loops_it;
""" """
...@@ -365,7 +367,7 @@ def make_reordered_loop( ...@@ -365,7 +367,7 @@ def make_reordered_loop(
specified loop_order. specified loop_order.
""" """
var = sub[f"lv{int(i)}"] var = sub[f"lv{i}"]
r = [] r = []
for index in loop_order: for index in loop_order:
# Note: the stride variable is not declared for broadcasted variables # Note: the stride variable is not declared for broadcasted variables
...@@ -383,7 +385,7 @@ def make_reordered_loop( ...@@ -383,7 +385,7 @@ def make_reordered_loop(
) )
declare_strides = f""" declare_strides = f"""
int init_strides[{int(nvars)}][{int(nnested)}] = {{ int init_strides[{nvars}][{nnested}] = {{
{strides} {strides}
}};""" }};"""
...@@ -394,33 +396,33 @@ def make_reordered_loop( ...@@ -394,33 +396,33 @@ def make_reordered_loop(
""" """
for i in range(nvars): for i in range(nvars):
var = sub[f"lv{int(i)}"] var = sub[f"lv{i}"]
declare_strides += f""" declare_strides += f"""
{ovar}_loops_rit = {ovar}_loops.rbegin();""" {ovar}_loops_rit = {ovar}_loops.rbegin();"""
for j in reversed(range(nnested)): for j in reversed(range(nnested)):
declare_strides += f""" declare_strides += f"""
int {var}_stride_l{int(j)} = init_strides[{int(i)}][{ovar}_loops_rit->second]; int {var}_stride_l{j} = init_strides[{i}][{ovar}_loops_rit->second];
++{ovar}_loops_rit; ++{ovar}_loops_rit;
""" """
declare_iter = "" declare_iter = ""
for i, dtype in enumerate(dtypes): for i, dtype in enumerate(dtypes):
var = sub[f"lv{int(i)}"] var = sub[f"lv{i}"]
declare_iter += f"{var}_iter = ({dtype}*)(PyArray_DATA({var}));\n" declare_iter += f"{var}_iter = ({dtype}*)(PyArray_DATA({var}));\n"
pointer_update = "" pointer_update = ""
for j, dtype in enumerate(dtypes): for j, dtype in enumerate(dtypes):
var = sub[f"lv{int(j)}"] var = sub[f"lv{j}"]
pointer_update += f"{dtype} &{var}_i = * ( {var}_iter" pointer_update += f"{dtype} &{var}_i = * ( {var}_iter"
for i in reversed(range(nnested)): for i in reversed(range(nnested)):
iterv = f"ITER_{int(i)}" iterv = f"ITER_{i}"
pointer_update += f"+{var}_stride_l{int(i)}*{iterv}" pointer_update += f"+{var}_stride_l{i}*{iterv}"
pointer_update += ");\n" pointer_update += ");\n"
loop = inner_task loop = inner_task
for i in reversed(range(nnested)): for i in reversed(range(nnested)):
iterv = f"ITER_{int(i)}" iterv = f"ITER_{i}"
total = f"TOTAL_{int(i)}" total = f"TOTAL_{i}"
update = "" update = ""
forloop = "" forloop = ""
# The pointers are defined only in the most inner loop # The pointers are defined only in the most inner loop
...@@ -434,36 +436,14 @@ def make_reordered_loop( ...@@ -434,36 +436,14 @@ def make_reordered_loop(
loop = f""" loop = f"""
{forloop} {forloop}
{{ // begin loop {int(i)} {{ // begin loop {i}
{update} {update}
{loop} {loop}
}} // end loop {int(i)} }} // end loop {i}
""" """
return f"{{\n{order_loops}\n{declare_totals}\n{declare_strides}\n{declare_iter}\n{loop}\n}}\n" code = "\n".join((order_loops, declare_totals, declare_strides, declare_iter, loop))
return f"{{\n{code}\n}}\n"
# print make_declare(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
# ('double', 'int', 'float'),
# dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
# print make_checks(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
# ('double', 'int', 'float'),
# dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
# print make_alloc(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
# 'double',
# dict(olv='out', lv0='x', lv1='y', lv2='z', fail="FAIL;"))
# print make_loop(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
# ('double', 'int', 'float'),
# (("C00;", "C%01;"), ("C10;", "C11;"), ("C20;", "C21;"), ("C30;", "C31;"),"C4;"),
# dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
# print make_loop(((0, 1, 2, 3), (3, 'x', 0, 'x'), (0, 'x', 'x', 'x')),
# ('double', 'int', 'float'),
# (("C00;", "C01;"), ("C10;", "C11;"), ("C20;", "C21;"), ("C30;", "C31;"),"C4;"),
# dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
################## ##################
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论