Commit c4702cd9 authored by Ricardo Vieira, committed by Ricardo Vieira

Numba Alloc: Patch so it works inside a Blockwise

Parent c513419c
...@@ -74,13 +74,13 @@ def numba_funcify_Alloc(op, node, **kwargs): ...@@ -74,13 +74,13 @@ def numba_funcify_Alloc(op, node, **kwargs):
f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")' f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")'
) )
check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4) check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4)
dtype = node.inputs[0].type.dtype
alloc_def_src = f""" alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}): def alloc(val, {", ".join(shape_var_names)}):
{shapes_to_items_src} {shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)} scalar_shape = {create_tuple_string(shape_var_item_names)}
{check_runtime_broadcast_src} {check_runtime_broadcast_src}
res = np.empty(scalar_shape, dtype=val.dtype) res = np.empty(scalar_shape, dtype=np.{dtype})
res[...] = val res[...] = val
return res return res
""" """
...@@ -88,10 +88,12 @@ def alloc(val, {", ".join(shape_var_names)}): ...@@ -88,10 +88,12 @@ def alloc(val, {", ".join(shape_var_names)}):
alloc_def_src, alloc_def_src,
"alloc", "alloc",
globals() | {"np": np}, globals() | {"np": np},
write_to_disk=True,
) )
cache_version = -1
cache_key = sha256( cache_key = sha256(
str((type(op), node.inputs[0].type.broadcastable)).encode() str((type(op), node.inputs[0].type.broadcastable, cache_version)).encode()
).hexdigest() ).hexdigest()
return numba_basic.numba_njit(alloc_fn), cache_key return numba_basic.numba_njit(alloc_fn), cache_key
......
...@@ -2,8 +2,8 @@ import numpy as np ...@@ -2,8 +2,8 @@ import numpy as np
import pytest import pytest
from pytensor import function from pytensor import function
from pytensor.tensor import tensor, tensor3 from pytensor.tensor import lvector, tensor, tensor3
from pytensor.tensor.basic import ARange from pytensor.tensor.basic import Alloc, ARange, constant
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.nlinalg import SVD, Det from pytensor.tensor.nlinalg import SVD, Det
from pytensor.tensor.slinalg import Cholesky, cholesky from pytensor.tensor.slinalg import Cholesky, cholesky
...@@ -70,3 +70,13 @@ def test_repeated_args(): ...@@ -70,3 +70,13 @@ def test_repeated_args():
final_node = fn.maker.fgraph.outputs[0].owner final_node = fn.maker.fgraph.outputs[0].owner
assert isinstance(final_node.op, BlockwiseWithCoreShape) assert isinstance(final_node.op, BlockwiseWithCoreShape)
assert final_node.inputs[0] is final_node.inputs[1] assert final_node.inputs[0] is final_node.inputs[1]
def test_blockwise_alloc():
    """Alloc wrapped in a Blockwise should compile and run under the numba backend."""
    value = lvector("val")
    rows = constant(2, dtype="int64")
    cols = constant(3, dtype="int64")
    # Core signature: three scalar inputs -> a fixed (2, 3) core output;
    # the vector input contributes one batch dimension.
    blockwise_alloc = Blockwise(Alloc(), signature="(),(),()->(2,3)")
    out = blockwise_alloc(value, rows, cols)
    assert out.type.ndim == 3
    compare_numba_and_py([value], [out], [np.arange(5)], eval_obj_mode=False)
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment