提交 7b86ac5a authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Thomas Wiecki

Repair references to bitwidth constants

上级 a0eb0ad0
......@@ -7,6 +7,7 @@ from theano.compile.function import pfunc
from theano.compile.io import In
from theano.compile.sharedvalue import shared
from theano.tensor import dmatrices, dmatrix, iscalar, lscalar
from theano.utils import PYTHON_INT_BITWIDTH
def data_of(s):
......@@ -570,7 +571,7 @@ class TestPfunc:
def test_default_updates_input(self):
x = shared(0)
y = shared(1)
if theano.configdefaults.python_int_bitwidth() == 32:
if PYTHON_INT_BITWIDTH == 32:
a = iscalar("a")
else:
a = lscalar("a")
......
......@@ -4,12 +4,13 @@ import pytest
import theano
from theano.compile.sharedvalue import SharedVariable, generic, shared
from theano.tensor import Tensor, TensorType
from theano.utils import PYTHON_INT_BITWIDTH
class TestSharedVariable:
def test_ctors(self):
if theano.configdefaults.python_int_bitwidth() == 32:
if PYTHON_INT_BITWIDTH == 32:
assert shared(7).type == theano.tensor.iscalar, shared(7).type
else:
assert shared(7).type == theano.tensor.lscalar, shared(7).type
......
......@@ -201,6 +201,7 @@ from theano.tensor import (
wvector,
zvector,
)
from theano.utils import PYTHON_INT_BITWIDTH
if config.mode == "FAST_COMPILE":
......@@ -5587,7 +5588,7 @@ class TestArithmeticCast:
class TestLongTensor:
def test_fit_int64(self):
bitwidth = theano.configdefaults.python_int_bitwidth()
bitwidth = PYTHON_INT_BITWIDTH
for exponent in range(bitwidth):
val = 2 ** exponent - 1
scalar_ct = constant(val)
......
......@@ -36,6 +36,7 @@ from theano.tensor.extra_ops import (
to_one_hot,
unravel_index,
)
from theano.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
def test_cpu_contiguous():
......@@ -105,7 +106,7 @@ class TestSearchsortedOp(utt.InferShapeTester):
def test_searchsortedOp_on_int_sorter(self):
compatible_types = ("int8", "int16", "int32")
if theano.configdefaults.python_int_bitwidth() == 64:
if PYTHON_INT_BITWIDTH == 64:
compatible_types += ("int64",)
# 'uint8', 'uint16', 'uint32', 'uint64')
for dtype in compatible_types:
......@@ -420,10 +421,9 @@ class TestRepeatOp(utt.InferShapeTester):
self.op = RepeatOp()
# uint64 always fails
# int64 and uint32 also fail if python int are 32-bit
ptr_bitwidth = theano.configdefaults.local_bitwidth()
if ptr_bitwidth == 64:
if LOCAL_BITWIDTH == 64:
self.numpy_unsupported_dtypes = ("uint64",)
if ptr_bitwidth == 32:
if LOCAL_BITWIDTH == 32:
self.numpy_unsupported_dtypes = ("uint32", "int64", "uint64")
def test_repeatOp(self):
......
......@@ -13,6 +13,7 @@ from theano.gradient import (
from theano.scalar import int32 as int_t
from theano.scalar import upcast
from theano.tensor import basic, nlinalg
from theano.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
class CpuContiguous(Op):
......@@ -108,10 +109,7 @@ class SearchsortedOp(Op):
return theano.Apply(self, [x, v], [out_type()])
else:
sorter = basic.as_tensor(sorter, ndim=1)
if (
theano.configdefaults.python_int_bitwidth() == 32
and sorter.dtype == "int64"
):
if PYTHON_INT_BITWIDTH == 32 and sorter.dtype == "int64":
raise TypeError(
"numpy.searchsorted with Python 32bit do not support a"
" sorter of int64."
......@@ -658,7 +656,7 @@ class RepeatOp(Op):
# Some dtypes are not supported by numpy's implementation of repeat.
# Until another one is available, we should fail at graph construction
# time, not wait for execution.
ptr_bitwidth = theano.configdefaults.local_bitwidth()
ptr_bitwidth = LOCAL_BITWIDTH
if ptr_bitwidth == 64:
numpy_unsupported_dtypes = ("uint64",)
if ptr_bitwidth == 32:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论