Unverified 提交 0f8c81c3 authored 作者: Kaustubh's avatar Kaustubh 提交者: GitHub

Fixed: Inv Numba implementation now correctly returns non-integral inverses of…

Fixed: Inv Numba implementation now correctly returns non-integral inverses of integer inputs (#627)
上级 e6c2fbc9
...@@ -21,6 +21,7 @@ from aesara.scalar.basic import ( ...@@ -21,6 +21,7 @@ from aesara.scalar.basic import (
Clip, Clip,
Composite, Composite,
Identity, Identity,
Inv,
Mul, Mul,
ScalarOp, ScalarOp,
Second, Second,
...@@ -170,3 +171,12 @@ def numba_funcify_Second(op, node, **kwargs): ...@@ -170,3 +171,12 @@ def numba_funcify_Second(op, node, **kwargs):
return y return y
return second return second
@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
@numba.njit(inline="always")
def inv(x):
return 1 / x
return inv
...@@ -796,6 +796,25 @@ def test_Cast(v, dtype): ...@@ -796,6 +796,25 @@ def test_Cast(v, dtype):
) )
@pytest.mark.parametrize(
"v, dtype",
[
(set_test_value(aet.iscalar(), np.array(10, dtype="int32")), aesb.float64),
],
)
def test_Inv(v, dtype):
g = aesb.inv(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, shape, ndim", "v, shape, ndim",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论