提交 ac14c1dd authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Fixed float typecasting for Dot implementation in Numba module

上级 263a7c71
......@@ -539,10 +539,12 @@ def int_to_float_fn(inputs, out_dtype):
return x.astype(args_dtype)
else:
args_dtype_sz = max([_arg.type.numpy_dtype.itemsize for _arg in inputs])
args_dtype = np.dtype(f"f{args_dtype_sz}")
@numba.njit(inline="always")
def inputs_cast(x):
return x
return x.astype(args_dtype)
return inputs_cast
......
......@@ -176,11 +176,7 @@ def eval_python_only(fn_inputs, fgraph, inputs):
_ = aesara_numba_fn(*inputs)
def compare_numba_and_py(
fgraph,
inputs,
assert_fn=None,
):
def compare_numba_and_py(fgraph, inputs, assert_fn=None):
"""Function to compare python graph output and Numba compiled output for testing equality
In the tests below computational graphs are defined in Aesara. These graphs are then passed to
......@@ -1839,6 +1835,15 @@ def test_BroadcastTo(x, shape, exc):
set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)),
None,
),
(
set_test_value(
aet.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")
),
set_test_value(
aet.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")
),
None,
),
(
set_test_value(aet.lmatrix(), rng.poisson(size=(3, 2))),
set_test_value(aet.fvector(), rng.random(size=(2,)).astype("float32")),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论