提交 c4702cd9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba Alloc: Patch so it works inside a Blockwise

上级 c513419c
......@@ -74,13 +74,13 @@ def numba_funcify_Alloc(op, node, **kwargs):
f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")'
)
check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4)
dtype = node.inputs[0].type.dtype
alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}):
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
{check_runtime_broadcast_src}
res = np.empty(scalar_shape, dtype=val.dtype)
res = np.empty(scalar_shape, dtype=np.{dtype})
res[...] = val
return res
"""
......@@ -88,10 +88,12 @@ def alloc(val, {", ".join(shape_var_names)}):
alloc_def_src,
"alloc",
globals() | {"np": np},
write_to_disk=True,
)
cache_version = -1
cache_key = sha256(
str((type(op), node.inputs[0].type.broadcastable)).encode()
str((type(op), node.inputs[0].type.broadcastable, cache_version)).encode()
).hexdigest()
return numba_basic.numba_njit(alloc_fn), cache_key
......
......@@ -2,8 +2,8 @@ import numpy as np
import pytest
from pytensor import function
from pytensor.tensor import tensor, tensor3
from pytensor.tensor.basic import ARange
from pytensor.tensor import lvector, tensor, tensor3
from pytensor.tensor.basic import Alloc, ARange, constant
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.nlinalg import SVD, Det
from pytensor.tensor.slinalg import Cholesky, cholesky
......@@ -70,3 +70,13 @@ def test_repeated_args():
final_node = fn.maker.fgraph.outputs[0].owner
assert isinstance(final_node.op, BlockwiseWithCoreShape)
assert final_node.inputs[0] is final_node.inputs[1]
def test_blockwise_alloc():
val = lvector("val")
out = Blockwise(Alloc(), signature="(),(),()->(2,3)")(
val, constant(2, dtype="int64"), constant(3, dtype="int64")
)
assert out.type.ndim == 3
compare_numba_and_py([val], [out], [np.arange(5)], eval_obj_mode=False)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论