提交 9352166c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba does not output numpy scalars

上级 a602a8ec
...@@ -6,6 +6,7 @@ import pytensor.tensor as pt ...@@ -6,6 +6,7 @@ import pytensor.tensor as pt
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.link.c.basic import DualLinker from pytensor.link.c.basic import DualLinker
from pytensor.link.numba import NumbaLinker
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
EQ, EQ,
ComplexError, ComplexError,
...@@ -368,6 +369,8 @@ class TestUpgradeToFloat: ...@@ -368,6 +369,8 @@ class TestUpgradeToFloat:
outi = fi(x_val) outi = fi(x_val)
outf = ff(x_val) outf = ff(x_val)
if not isinstance(ff.maker.linker, NumbaLinker):
# Numba doesn't return numpy scalars
assert outi.dtype == outf.dtype, "incorrect dtype" assert outi.dtype == outf.dtype, "incorrect dtype"
assert np.allclose(outi, outf), "insufficient precision" assert np.allclose(outi, outf), "insufficient precision"
...@@ -389,6 +392,8 @@ class TestUpgradeToFloat: ...@@ -389,6 +392,8 @@ class TestUpgradeToFloat:
outi = fi(x_val, y_val) outi = fi(x_val, y_val)
outf = ff(x_val, y_val) outf = ff(x_val, y_val)
if not isinstance(ff.maker.linker, NumbaLinker):
# Numba doesn't return numpy scalars
assert outi.dtype == outf.dtype, "incorrect dtype" assert outi.dtype == outf.dtype, "incorrect dtype"
assert np.allclose(outi, outf), "insufficient precision" assert np.allclose(outi, outf), "insufficient precision"
...@@ -414,6 +419,8 @@ class TestUpgradeToFloat: ...@@ -414,6 +419,8 @@ class TestUpgradeToFloat:
outi = fi(x_val, y_val) outi = fi(x_val, y_val)
outf = ff(x_val, y_val) outf = ff(x_val, y_val)
if not isinstance(ff.maker.linker, NumbaLinker):
# Numba doesn't return numpy scalars
assert outi.dtype == outf.dtype, "incorrect dtype" assert outi.dtype == outf.dtype, "incorrect dtype"
assert np.allclose(outi, outf), "insufficient precision" assert np.allclose(outi, outf), "insufficient precision"
......
...@@ -18,6 +18,7 @@ from pytensor.gradient import grad, hessian ...@@ -18,6 +18,7 @@ from pytensor.gradient import grad, hessian
from pytensor.graph.basic import Apply, equal_computations from pytensor.graph.basic import Apply, equal_computations
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.link.numba import NumbaLinker
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar import autocast_float, autocast_float_as from pytensor.scalar import autocast_float, autocast_float_as
from pytensor.tensor import NoneConst, vectorize from pytensor.tensor import NoneConst, vectorize
...@@ -2193,11 +2194,15 @@ def test_ScalarFromTensor(cast_policy): ...@@ -2193,11 +2194,15 @@ def test_ScalarFromTensor(cast_policy):
assert ss.owner.op is scalar_from_tensor assert ss.owner.op is scalar_from_tensor
assert ss.type.dtype == tc.type.dtype assert ss.type.dtype == tc.type.dtype
v = eval_outputs([ss]) mode = get_default_mode()
v = eval_outputs([ss], mode=mode)
assert v == 56 assert v == 56
if isinstance(mode.linker, NumbaLinker):
# Numba doesn't return numpy scalars
assert isinstance(v, int)
else:
assert v.shape == () assert v.shape == ()
if cast_policy == "custom": if cast_policy == "custom":
assert isinstance(v, np.int8) assert isinstance(v, np.int8)
elif cast_policy == "numpy+floatX": elif cast_policy == "numpy+floatX":
...@@ -2209,6 +2214,9 @@ def test_ScalarFromTensor(cast_policy): ...@@ -2209,6 +2214,9 @@ def test_ScalarFromTensor(cast_policy):
fff = function([pts], ss) fff = function([pts], ss)
v = fff(np.asarray(5)) v = fff(np.asarray(5))
assert v == 5 assert v == 5
if isinstance(mode.linker, NumbaLinker):
assert isinstance(v, int)
else:
assert isinstance(v, np.int64) assert isinstance(v, np.int64)
assert v.shape == () assert v.shape == ()
......
...@@ -3,6 +3,7 @@ import sys ...@@ -3,6 +3,7 @@ import sys
import warnings import warnings
from copy import copy, deepcopy from copy import copy, deepcopy
from functools import wraps from functools import wraps
from numbers import Number
import numpy as np import numpy as np
import pytest import pytest
...@@ -259,6 +260,10 @@ class InferShapeTester: ...@@ -259,6 +260,10 @@ class InferShapeTester:
numeric_outputs = outputs_function(*numeric_inputs) numeric_outputs = outputs_function(*numeric_inputs)
numeric_shapes = shapes_function(*numeric_inputs) numeric_shapes = shapes_function(*numeric_inputs)
for out, shape in zip(numeric_outputs, numeric_shapes, strict=True): for out, shape in zip(numeric_outputs, numeric_shapes, strict=True):
if not hasattr(out, "shape"):
# Numba downcasts scalars to native Python types, which don't have shape
assert isinstance(out, Number)
out = np.asarray(out)
assert np.all(out.shape == shape), (out.shape, shape) assert np.all(out.shape == shape), (out.shape, shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论