提交 7b86ac5a authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Thomas Wiecki

Repair references to bitwidth constants

上级 a0eb0ad0
...@@ -7,6 +7,7 @@ from theano.compile.function import pfunc ...@@ -7,6 +7,7 @@ from theano.compile.function import pfunc
from theano.compile.io import In from theano.compile.io import In
from theano.compile.sharedvalue import shared from theano.compile.sharedvalue import shared
from theano.tensor import dmatrices, dmatrix, iscalar, lscalar from theano.tensor import dmatrices, dmatrix, iscalar, lscalar
from theano.utils import PYTHON_INT_BITWIDTH
def data_of(s): def data_of(s):
...@@ -570,7 +571,7 @@ class TestPfunc: ...@@ -570,7 +571,7 @@ class TestPfunc:
def test_default_updates_input(self): def test_default_updates_input(self):
x = shared(0) x = shared(0)
y = shared(1) y = shared(1)
if theano.configdefaults.python_int_bitwidth() == 32: if PYTHON_INT_BITWIDTH == 32:
a = iscalar("a") a = iscalar("a")
else: else:
a = lscalar("a") a = lscalar("a")
......
...@@ -4,12 +4,13 @@ import pytest ...@@ -4,12 +4,13 @@ import pytest
import theano import theano
from theano.compile.sharedvalue import SharedVariable, generic, shared from theano.compile.sharedvalue import SharedVariable, generic, shared
from theano.tensor import Tensor, TensorType from theano.tensor import Tensor, TensorType
from theano.utils import PYTHON_INT_BITWIDTH
class TestSharedVariable: class TestSharedVariable:
def test_ctors(self): def test_ctors(self):
if theano.configdefaults.python_int_bitwidth() == 32: if PYTHON_INT_BITWIDTH == 32:
assert shared(7).type == theano.tensor.iscalar, shared(7).type assert shared(7).type == theano.tensor.iscalar, shared(7).type
else: else:
assert shared(7).type == theano.tensor.lscalar, shared(7).type assert shared(7).type == theano.tensor.lscalar, shared(7).type
......
...@@ -201,6 +201,7 @@ from theano.tensor import ( ...@@ -201,6 +201,7 @@ from theano.tensor import (
wvector, wvector,
zvector, zvector,
) )
from theano.utils import PYTHON_INT_BITWIDTH
if config.mode == "FAST_COMPILE": if config.mode == "FAST_COMPILE":
...@@ -5587,7 +5588,7 @@ class TestArithmeticCast: ...@@ -5587,7 +5588,7 @@ class TestArithmeticCast:
class TestLongTensor: class TestLongTensor:
def test_fit_int64(self): def test_fit_int64(self):
bitwidth = theano.configdefaults.python_int_bitwidth() bitwidth = PYTHON_INT_BITWIDTH
for exponent in range(bitwidth): for exponent in range(bitwidth):
val = 2 ** exponent - 1 val = 2 ** exponent - 1
scalar_ct = constant(val) scalar_ct = constant(val)
......
...@@ -36,6 +36,7 @@ from theano.tensor.extra_ops import ( ...@@ -36,6 +36,7 @@ from theano.tensor.extra_ops import (
to_one_hot, to_one_hot,
unravel_index, unravel_index,
) )
from theano.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
def test_cpu_contiguous(): def test_cpu_contiguous():
...@@ -105,7 +106,7 @@ class TestSearchsortedOp(utt.InferShapeTester): ...@@ -105,7 +106,7 @@ class TestSearchsortedOp(utt.InferShapeTester):
def test_searchsortedOp_on_int_sorter(self): def test_searchsortedOp_on_int_sorter(self):
compatible_types = ("int8", "int16", "int32") compatible_types = ("int8", "int16", "int32")
if theano.configdefaults.python_int_bitwidth() == 64: if PYTHON_INT_BITWIDTH == 64:
compatible_types += ("int64",) compatible_types += ("int64",)
# 'uint8', 'uint16', 'uint32', 'uint64') # 'uint8', 'uint16', 'uint32', 'uint64')
for dtype in compatible_types: for dtype in compatible_types:
...@@ -420,10 +421,9 @@ class TestRepeatOp(utt.InferShapeTester): ...@@ -420,10 +421,9 @@ class TestRepeatOp(utt.InferShapeTester):
self.op = RepeatOp() self.op = RepeatOp()
# uint64 always fails # uint64 always fails
# int64 and uint32 also fail if python int are 32-bit # int64 and uint32 also fail if python int are 32-bit
ptr_bitwidth = theano.configdefaults.local_bitwidth() if LOCAL_BITWIDTH == 64:
if ptr_bitwidth == 64:
self.numpy_unsupported_dtypes = ("uint64",) self.numpy_unsupported_dtypes = ("uint64",)
if ptr_bitwidth == 32: if LOCAL_BITWIDTH == 32:
self.numpy_unsupported_dtypes = ("uint32", "int64", "uint64") self.numpy_unsupported_dtypes = ("uint32", "int64", "uint64")
def test_repeatOp(self): def test_repeatOp(self):
......
...@@ -13,6 +13,7 @@ from theano.gradient import ( ...@@ -13,6 +13,7 @@ from theano.gradient import (
from theano.scalar import int32 as int_t from theano.scalar import int32 as int_t
from theano.scalar import upcast from theano.scalar import upcast
from theano.tensor import basic, nlinalg from theano.tensor import basic, nlinalg
from theano.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
class CpuContiguous(Op): class CpuContiguous(Op):
...@@ -108,10 +109,7 @@ class SearchsortedOp(Op): ...@@ -108,10 +109,7 @@ class SearchsortedOp(Op):
return theano.Apply(self, [x, v], [out_type()]) return theano.Apply(self, [x, v], [out_type()])
else: else:
sorter = basic.as_tensor(sorter, ndim=1) sorter = basic.as_tensor(sorter, ndim=1)
if ( if PYTHON_INT_BITWIDTH == 32 and sorter.dtype == "int64":
theano.configdefaults.python_int_bitwidth() == 32
and sorter.dtype == "int64"
):
raise TypeError( raise TypeError(
"numpy.searchsorted with Python 32bit do not support a" "numpy.searchsorted with Python 32bit do not support a"
" sorter of int64." " sorter of int64."
...@@ -658,7 +656,7 @@ class RepeatOp(Op): ...@@ -658,7 +656,7 @@ class RepeatOp(Op):
# Some dtypes are not supported by numpy's implementation of repeat. # Some dtypes are not supported by numpy's implementation of repeat.
# Until another one is available, we should fail at graph construction # Until another one is available, we should fail at graph construction
# time, not wait for execution. # time, not wait for execution.
ptr_bitwidth = theano.configdefaults.local_bitwidth() ptr_bitwidth = LOCAL_BITWIDTH
if ptr_bitwidth == 64: if ptr_bitwidth == 64:
numpy_unsupported_dtypes = ("uint64",) numpy_unsupported_dtypes = ("uint64",)
if ptr_bitwidth == 32: if ptr_bitwidth == 32:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论