提交 77835528 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Thomas Wiecki

Move exceptions defined in theano.tensor to theano.tensor.exceptions

上级 862aec54
......@@ -108,7 +108,8 @@ TODO: talk about OPTIMIZATION STAGES
.. testcode::
from theano.tensor.opt import get_scalar_constant_value, NotScalarConstantError
from theano.tensor.basic import get_scalar_constant_value
from theano.tensor.exceptions import NotScalarConstantError
# Remove any fibby(zeros(...))
@theano.tensor.opt.register_specialize
......
......@@ -47,4 +47,4 @@ lines_after_imports = 2
lines_between_sections = 1
honor_noqa = True
skip_gitignore = True
skip = theano/version.py
\ No newline at end of file
skip = theano/version.py, **/__init__.py
\ No newline at end of file
......@@ -7,7 +7,8 @@ import theano
import theano.tensor as tt
from tests import unittest_tools as utt
from theano.compile.mode import Mode
from theano.tensor.basic import NotScalarConstantError, _allclose
from theano.tensor.basic import _allclose
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.nnet import conv, conv2d
from theano.tensor.type import dmatrix, dtensor3, dtensor4, dvector, scalar, tensor4
......
......@@ -112,13 +112,13 @@ if (
def get_scalar_constant_value(v):
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`.
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
If `v` is the output of dim-shuffles, fills, allocs, rebroadcasts, cast
this function digs through them.
If theano.sparse is also there, we will look over CSM op.
If ``theano.sparse`` is also there, we will look over CSM `Op`.
If `v` is not some view of constant data, then raise a
tensor.basic.NotScalarConstantError.
`NotScalarConstantError`.
"""
# Is it necessary to test for presence of theano.sparse at runtime?
sparse = globals().get("sparse")
......
......@@ -894,7 +894,7 @@ class SpecifyShape(COp):
s = theano.tensor.get_scalar_constant_value(node.inputs[1][dim])
s = theano.tensor.as_tensor_variable(s)
new_shape.append(s)
except theano.tensor.basic.NotScalarConstantError:
except theano.tensor.exceptions.NotScalarConstantError:
new_shape.append(node.inputs[1][dim])
assert len(new_shape) == len(xshape)
......
......@@ -37,7 +37,8 @@ from theano.scalar import bool as bool_t
from theano.scalar import constant, get_scalar_type
from theano.scalar import int32 as int_t
from theano.scalar import uint32 as uint32_t
from theano.tensor.basic import NotScalarConstantError, as_tensor_variable
from theano.tensor.basic import as_tensor_variable
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.extra_ops import cpu_contiguous
from theano.tensor.nnet.abstract_conv import (
AbstractConv2d,
......
......@@ -24,7 +24,8 @@ from theano.gpuarray.type import GpuArrayType
from theano.graph.basic import Apply
from theano.graph.op import _NoPythonOp
from theano.scalar import as_scalar
from theano.tensor.basic import NotScalarConstantError, get_scalar_constant_value
from theano.tensor.basic import get_scalar_constant_value
from theano.tensor.exceptions import NotScalarConstantError
class GPUAMultinomialFromUniform(GpuKernelBaseCOp, _NoPythonOp):
......
......@@ -16,8 +16,9 @@ from theano.gpuarray.type import GpuArrayType, get_context, move_to_gpu
from theano.graph.basic import Constant
from theano.graph.op import Op
from theano.graph.opt import copy_stack_trace, inherit_stack_trace, local_optimizer
from theano.tensor.basic import NotScalarConstantError, get_scalar_constant_value
from theano.tensor.basic import get_scalar_constant_value
from theano.tensor.elemwise import DimShuffle
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.type import TensorType
......
......@@ -2114,7 +2114,7 @@ def _is_zero(x):
try:
constant_value = theano.get_scalar_constant_value(x)
no_constant_value = False
except theano.tensor.basic.NotScalarConstantError:
except theano.tensor.exceptions.NotScalarConstantError:
pass
if no_constant_value:
......
......@@ -8,6 +8,7 @@ from theano.tensor import basic as tt
from theano.tensor.basic import Dot
from theano.tensor.blas import Dot22
from theano.tensor.elemwise import DimShuffle, Prod
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.nlinalg import (
MatrixInverse,
det,
......@@ -210,7 +211,7 @@ def is_positive(v):
if v.owner and v.owner.op == tt.pow:
try:
exponent = tt.get_scalar_constant_value(v.owner.inputs[1])
except tt.NotScalarConstantError:
except NotScalarConstantError:
return False
if 0 == exponent % 2:
return True
......
......@@ -27,7 +27,7 @@ from theano.scan import utils
from theano.scan.op import Scan
from theano.scan.utils import safe_new, traverse
from theano.tensor import opt
from theano.tensor.basic import NotScalarConstantError
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.type import TensorType, integer_dtypes
from theano.updates import OrderedUpdates
......
......@@ -5,6 +5,7 @@ __docformat__ = "restructuredtext en"
import warnings
import theano.tensor.exceptions
from theano.compile.ops import shape, specify_shape
from theano.gradient import consider_constant, grad, hessian, jacobian
from theano.tensor import sharedvar # adds shared-variable constructors
......
......@@ -23,6 +23,7 @@ from theano.printing import min_informative_str, pprint
from theano.scalar import int32
from theano.tensor import elemwise
from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise, Sum, scalar_elemwise
from theano.tensor.exceptions import EmptyConstantError, NotScalarConstantError
from theano.tensor.type import (
TensorType,
complex_dtypes,
......@@ -46,10 +47,6 @@ _logger = logging.getLogger("theano.tensor.basic")
__docformat__ = "restructuredtext en"
class ShapeError(Exception):
"""Raised when the shape cannot be computed."""
def check_equal_numpy(x, y):
"""
Return True iff x and y are equal.
......@@ -333,20 +330,6 @@ def _allclose(a, b, rtol=None, atol=None):
return np.allclose(a, b, atol=atol_, rtol=rtol_)
class NotScalarConstantError(Exception):
"""
Raised by get_scalar_constant_value if called on something that is
not a scalar constant.
"""
class EmptyConstantError(NotScalarConstantError):
"""
Raised by get_scalar_const_value if called on something that is a
zero dimensional constant.
"""
def numpy_scalar(data):
"""Return a scalar stored in a numpy ndarray.
......@@ -1867,7 +1850,7 @@ def round(a, mode=None):
elif mode == "half_to_even":
return round_half_to_even(a)
else:
raise Exception(f"round mode {mode} is not implemented.")
raise NotImplementedError(f"round mode {mode} is not implemented.")
@scalar_elemwise
......
......@@ -162,6 +162,7 @@ from theano.scalar import bool as bool_t
from theano.tensor import basic as tt
from theano.tensor.blas_headers import blas_header_text, blas_header_version
from theano.tensor.elemwise import DimShuffle, Elemwise
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.opt import in2out, local_dimshuffle_lift
from theano.tensor.type import integer_dtypes, values_eq_approx_remove_inf_nan
from theano.utils import memoize
......@@ -1729,7 +1730,7 @@ def local_gemm_to_ger(fgraph, node):
yv = y.dimshuffle(1)
try:
bval = tt.get_scalar_constant_value(b)
except tt.NotScalarConstantError:
except NotScalarConstantError:
# b isn't a constant, GEMM is doing useful pre-scaling
return
......
class ShapeError(Exception):
    """Exception signaling that a variable's shape could not be computed."""
class NotScalarConstantError(Exception):
    """
    Raised by `get_scalar_constant_value` if called on something that is
    not a scalar constant.
    """


class EmptyConstantError(NotScalarConstantError):
    """
    Raised by `get_scalar_constant_value` if called on something that is a
    zero dimensional constant.
    """
class AdvancedIndexingError(TypeError):
    """Error raised when `Subtensor` is asked to perform advanced indexing."""
......@@ -19,6 +19,7 @@ from theano.scalar import int32 as int_t
from theano.scalar import upcast
from theano.tensor import basic as tt
from theano.tensor import nlinalg
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from theano.tensor.type import (
TensorType,
......@@ -690,7 +691,7 @@ class RepeatOp(Op):
else:
try:
const_reps = tt.get_scalar_constant_value(repeats)
except tt.NotScalarConstantError:
except NotScalarConstantError:
const_reps = None
if const_reps == 1:
broadcastable = x.broadcastable
......
......@@ -483,7 +483,7 @@ class ConvOp(OpenMPOp):
all_shape = self.has_all_shape(imshp, kshp, nkern, bsize)
if (unroll_batch or unroll_kern) and not all_shape:
raise Exception(
raise ValueError(
"In ConvOp, when using unroll_batch and"
" unroll_nkern, all shape are needed"
)
......@@ -613,10 +613,10 @@ class ConvOp(OpenMPOp):
self.out_mode = output_mode
if self.out_mode not in ["valid", "full"]:
raise Exception(f"Mode {self.out_mode} not implemented")
raise NotImplementedError(f"Mode {self.out_mode} not implemented")
if any((shp is not None) and (shp <= 0) for shp in self.outshp):
raise Exception(
raise ValueError(
"Bad size for the output shape. Verify that [post-"
f"supersampling] input shape ({self.imshp_logical}) and kern"
f" shape({self.kshp_logical}) are ok. (Hint: kerns must fit inside"
......@@ -1038,7 +1038,7 @@ class ConvOp(OpenMPOp):
all_shape = self.has_all_shape(self.imshp, self.kshp, self.nkern, self.bsize)
if not all_shape and (self.dx != 1 or self.dy != 1):
raise Exception(
raise ValueError(
"ConvOp.grad when dx!=1 or dy!=1 we must have all "
"the optional shape information"
)
......@@ -1486,7 +1486,9 @@ if(kerns_dim[1] != img2d_dim[1]){
elif node.inputs[0].type.dtype == "float64":
d["type"] = "double"
else:
raise Exception(f"Type {node.inputs[0].type.dtype} not implemented")
raise NotImplementedError(
f"Type {node.inputs[0].type.dtype} not implemented"
)
d["gemm"] = "dgemm_"
if not d["type"] == "double":
d["gemm"] = "sgemm_"
......@@ -2042,8 +2044,8 @@ def gen_conv_code_unroll_batch_kern(d, unroll_bsize=1, unroll_ksize=1):
or "unroll_biter" in d
or "unroll_kiter" in d
):
raise Exception(
"We can't use this dictionary as we will overwrite some of its containt"
raise ValueError(
"We can't use this dictionary as we will overwrite some of its content"
)
d = d.copy()
......
......@@ -34,6 +34,7 @@ from theano.tensor import basic as tt
from theano.tensor import extra_ops, opt
from theano.tensor.basic import MaxAndArgmax, as_tensor_variable, log
from theano.tensor.elemwise import Elemwise
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.nnet.blocksparse import sparse_block_dot
from theano.tensor.nnet.sigm import sigmoid, softplus
from theano.tensor.opt import (
......@@ -1758,7 +1759,7 @@ def _check_rows_is_arange_len_labels(fgraph, rows, labels):
def _is_const(z, val, approx=False):
try:
maybe = opt.get_scalar_constant_value(z)
except tt.NotScalarConstantError:
except NotScalarConstantError:
return False
if approx:
return np.allclose(maybe, val)
......
......@@ -20,8 +20,8 @@ from theano.graph.utils import MethodNotDefined
from theano.printing import pprint
from theano.tensor import basic as tt
from theano.tensor import opt
from theano.tensor.basic import NotScalarConstantError
from theano.tensor.elemwise import Elemwise
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.type import TensorType, values_eq_approx_remove_inf
......@@ -463,7 +463,7 @@ def _is_1(expr):
try:
v = opt.get_scalar_constant_value(expr)
return np.allclose(v, 1)
except tt.NotScalarConstantError:
except NotScalarConstantError:
return False
......@@ -1074,7 +1074,7 @@ def local_1msigmoid(fgraph, node):
if sub_r.owner and sub_r.owner.op == sigmoid:
try:
val_l = opt.get_scalar_constant_value(sub_l)
except tt.NotScalarConstantError:
except NotScalarConstantError:
return
if np.allclose(np.sum(val_l), 1):
out = sigmoid(-sub_r.owner.inputs[0])
......
......@@ -62,11 +62,9 @@ from theano.tensor.basic import (
Dot,
Flatten,
Join,
NotScalarConstantError,
Rebroadcast,
Reshape,
ScalarFromTensor,
ShapeError,
Split,
TensorFromScalar,
Tile,
......@@ -99,6 +97,7 @@ from theano.tensor.elemwise import (
ProdWithoutZeros,
Sum,
)
from theano.tensor.exceptions import NotScalarConstantError, ShapeError
from theano.tensor.extra_ops import broadcast_shape
from theano.tensor.sort import TopKOp
from theano.tensor.subtensor import (
......@@ -1115,7 +1114,7 @@ class ShapeFeature(toolbox.Feature):
"Code called by infer_shape failed raising a "
"NotImplementedError. Raising NotImplementedError to "
"indicate that a shape cannot be computed is no longer "
"supported, and one should now use tensor.ShapeError "
"supported, and one should now use ShapeError "
f"instead. The original exception message is: {e}"
).with_traceback(e.__traceback__)
except Exception as e:
......
......@@ -26,6 +26,7 @@ from theano.tensor.basic import (
get_scalar_constant_value,
)
from theano.tensor.elemwise import DimShuffle
from theano.tensor.exceptions import AdvancedIndexingError, ShapeError
from theano.tensor.inc_code import inc_code
from theano.tensor.type import (
bscalar,
......@@ -62,13 +63,6 @@ invalid_tensor_types = (
)
class AdvancedIndexingError(TypeError):
"""
Raised when Subtensor is asked to perform advanced indexing.
"""
def as_index_constant(a):
"""Convert Python literals to Theano constants--when possible--in Subtensor arguments.
......@@ -2344,7 +2338,7 @@ class AdvancedSubtensor(Op):
isinstance(idx, (np.bool_, bool))
or getattr(idx, "dtype", None) == "bool"
):
raise theano.tensor.basic.ShapeError(
raise ShapeError(
"Shape inference for boolean indices is not implemented"
)
# The `ishapes` entries for `SliceType`s will be None, and
......
......@@ -9,6 +9,7 @@ import theano
from theano.configdefaults import config
from theano.graph.basic import Constant, Variable
from theano.scalar import ComplexError, IntegerDivisionError
from theano.tensor.exceptions import AdvancedIndexingError
from theano.tensor.type import TensorType
from theano.tensor.utils import hash_from_ndarray
......@@ -533,7 +534,7 @@ class _tensor_py_operators:
if arg is not np.newaxis:
try:
theano.tensor.subtensor.Subtensor.convert(arg)
except theano.tensor.subtensor.AdvancedIndexingError:
except AdvancedIndexingError:
if advanced:
axis = None
break
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论