提交 ac14c1dd authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Fixed float typecasting for Dot implementation in Numba module

上级 263a7c71
...@@ -539,10 +539,12 @@ def int_to_float_fn(inputs, out_dtype): ...@@ -539,10 +539,12 @@ def int_to_float_fn(inputs, out_dtype):
return x.astype(args_dtype) return x.astype(args_dtype)
else: else:
args_dtype_sz = max([_arg.type.numpy_dtype.itemsize for _arg in inputs])
args_dtype = np.dtype(f"f{args_dtype_sz}")
@numba.njit(inline="always") @numba.njit(inline="always")
def inputs_cast(x): def inputs_cast(x):
return x return x.astype(args_dtype)
return inputs_cast return inputs_cast
......
...@@ -176,11 +176,7 @@ def eval_python_only(fn_inputs, fgraph, inputs): ...@@ -176,11 +176,7 @@ def eval_python_only(fn_inputs, fgraph, inputs):
_ = aesara_numba_fn(*inputs) _ = aesara_numba_fn(*inputs)
def compare_numba_and_py( def compare_numba_and_py(fgraph, inputs, assert_fn=None):
fgraph,
inputs,
assert_fn=None,
):
"""Function to compare python graph output and Numba compiled output for testing equality """Function to compare python graph output and Numba compiled output for testing equality
In the tests below computational graphs are defined in Aesara. These graphs are then passed to In the tests below computational graphs are defined in Aesara. These graphs are then passed to
...@@ -1839,6 +1835,15 @@ def test_BroadcastTo(x, shape, exc): ...@@ -1839,6 +1835,15 @@ def test_BroadcastTo(x, shape, exc):
set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)), set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)),
None, None,
), ),
(
set_test_value(
aet.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")
),
set_test_value(
aet.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")
),
None,
),
( (
set_test_value(aet.lmatrix(), rng.poisson(size=(3, 2))), set_test_value(aet.lmatrix(), rng.poisson(size=(3, 2))),
set_test_value(aet.fvector(), rng.random(size=(2,)).astype("float32")), set_test_value(aet.fvector(), rng.random(size=(2,)).astype("float32")),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论