提交 9352166c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba does not output numpy scalars

上级 a602a8ec
......@@ -6,6 +6,7 @@ import pytensor.tensor as pt
from pytensor.compile.mode import Mode
from pytensor.graph.fg import FunctionGraph
from pytensor.link.c.basic import DualLinker
from pytensor.link.numba import NumbaLinker
from pytensor.scalar.basic import (
EQ,
ComplexError,
......@@ -368,7 +369,9 @@ class TestUpgradeToFloat:
outi = fi(x_val)
outf = ff(x_val)
assert outi.dtype == outf.dtype, "incorrect dtype"
if not isinstance(ff.maker.linker, NumbaLinker):
# Numba doesn't return numpy scalars
assert outi.dtype == outf.dtype, "incorrect dtype"
assert np.allclose(outi, outf), "insufficient precision"
@staticmethod
......@@ -389,7 +392,9 @@ class TestUpgradeToFloat:
outi = fi(x_val, y_val)
outf = ff(x_val, y_val)
assert outi.dtype == outf.dtype, "incorrect dtype"
if not isinstance(ff.maker.linker, NumbaLinker):
# Numba doesn't return numpy scalars
assert outi.dtype == outf.dtype, "incorrect dtype"
assert np.allclose(outi, outf), "insufficient precision"
def test_true_div(self):
......@@ -414,7 +419,9 @@ class TestUpgradeToFloat:
outi = fi(x_val, y_val)
outf = ff(x_val, y_val)
assert outi.dtype == outf.dtype, "incorrect dtype"
if not isinstance(ff.maker.linker, NumbaLinker):
# Numba doesn't return numpy scalars
assert outi.dtype == outf.dtype, "incorrect dtype"
assert np.allclose(outi, outf), "insufficient precision"
def test_unary(self):
......
......@@ -18,6 +18,7 @@ from pytensor.gradient import grad, hessian
from pytensor.graph.basic import Apply, equal_computations
from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace
from pytensor.link.numba import NumbaLinker
from pytensor.raise_op import Assert
from pytensor.scalar import autocast_float, autocast_float_as
from pytensor.tensor import NoneConst, vectorize
......@@ -2193,15 +2194,19 @@ def test_ScalarFromTensor(cast_policy):
assert ss.owner.op is scalar_from_tensor
assert ss.type.dtype == tc.type.dtype
v = eval_outputs([ss])
mode = get_default_mode()
v = eval_outputs([ss], mode=mode)
assert v == 56
assert v.shape == ()
if cast_policy == "custom":
assert isinstance(v, np.int8)
elif cast_policy == "numpy+floatX":
assert isinstance(v, np.int64)
if isinstance(mode.linker, NumbaLinker):
# Numba doesn't return numpy scalars
assert isinstance(v, int)
else:
assert v.shape == ()
if cast_policy == "custom":
assert isinstance(v, np.int8)
elif cast_policy == "numpy+floatX":
assert isinstance(v, np.int64)
pts = lscalar()
ss = scalar_from_tensor(pts)
......@@ -2209,8 +2214,11 @@ def test_ScalarFromTensor(cast_policy):
fff = function([pts], ss)
v = fff(np.asarray(5))
assert v == 5
assert isinstance(v, np.int64)
assert v.shape == ()
if isinstance(mode.linker, NumbaLinker):
assert isinstance(v, int)
else:
assert isinstance(v, np.int64)
assert v.shape == ()
with pytest.raises(TypeError):
scalar_from_tensor(vector())
......
......@@ -3,6 +3,7 @@ import sys
import warnings
from copy import copy, deepcopy
from functools import wraps
from numbers import Number
import numpy as np
import pytest
......@@ -259,6 +260,10 @@ class InferShapeTester:
numeric_outputs = outputs_function(*numeric_inputs)
numeric_shapes = shapes_function(*numeric_inputs)
for out, shape in zip(numeric_outputs, numeric_shapes, strict=True):
if not hasattr(out, "shape"):
# Numba downcasts scalars to native Python types, which don't have shape
assert isinstance(out, Number)
out = np.asarray(out)
assert np.all(out.shape == shape), (out.shape, shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论