提交 77835528 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Thomas Wiecki

Move exceptions defined in theano.tensor to theano.tensor.exceptions

上级 862aec54
...@@ -108,7 +108,8 @@ TODO: talk about OPTIMIZATION STAGES ...@@ -108,7 +108,8 @@ TODO: talk about OPTIMIZATION STAGES
.. testcode:: .. testcode::
from theano.tensor.opt import get_scalar_constant_value, NotScalarConstantError from theano.tensor.basic import get_scalar_constant_value
from theano.tensor.exceptions import NotScalarConstantError
# Remove any fibby(zeros(...)) # Remove any fibby(zeros(...))
@theano.tensor.opt.register_specialize @theano.tensor.opt.register_specialize
......
...@@ -47,4 +47,4 @@ lines_after_imports = 2 ...@@ -47,4 +47,4 @@ lines_after_imports = 2
lines_between_sections = 1 lines_between_sections = 1
honor_noqa = True honor_noqa = True
skip_gitignore = True skip_gitignore = True
skip = theano/version.py skip = theano/version.py, **/__init__.py
\ No newline at end of file \ No newline at end of file
...@@ -7,7 +7,8 @@ import theano ...@@ -7,7 +7,8 @@ import theano
import theano.tensor as tt import theano.tensor as tt
from tests import unittest_tools as utt from tests import unittest_tools as utt
from theano.compile.mode import Mode from theano.compile.mode import Mode
from theano.tensor.basic import NotScalarConstantError, _allclose from theano.tensor.basic import _allclose
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.nnet import conv, conv2d from theano.tensor.nnet import conv, conv2d
from theano.tensor.type import dmatrix, dtensor3, dtensor4, dvector, scalar, tensor4 from theano.tensor.type import dmatrix, dtensor3, dtensor4, dvector, scalar, tensor4
......
...@@ -112,13 +112,13 @@ if ( ...@@ -112,13 +112,13 @@ if (
def get_scalar_constant_value(v): def get_scalar_constant_value(v):
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`. """Return the constant scalar (i.e. 0-D) value underlying variable `v`.
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast If `v` is the output of dim-shuffles, fills, allocs, rebroadcasts, cast
this function digs through them. this function digs through them.
If theano.sparse is also there, we will look over CSM op. If ``theano.sparse`` is also there, we will look over CSM `Op`.
If `v` is not some view of constant data, then raise a If `v` is not some view of constant data, then raise a
tensor.basic.NotScalarConstantError. `NotScalarConstantError`.
""" """
# Is it necessary to test for presence of theano.sparse at runtime? # Is it necessary to test for presence of theano.sparse at runtime?
sparse = globals().get("sparse") sparse = globals().get("sparse")
......
...@@ -894,7 +894,7 @@ class SpecifyShape(COp): ...@@ -894,7 +894,7 @@ class SpecifyShape(COp):
s = theano.tensor.get_scalar_constant_value(node.inputs[1][dim]) s = theano.tensor.get_scalar_constant_value(node.inputs[1][dim])
s = theano.tensor.as_tensor_variable(s) s = theano.tensor.as_tensor_variable(s)
new_shape.append(s) new_shape.append(s)
except theano.tensor.basic.NotScalarConstantError: except theano.tensor.exceptions.NotScalarConstantError:
new_shape.append(node.inputs[1][dim]) new_shape.append(node.inputs[1][dim])
assert len(new_shape) == len(xshape) assert len(new_shape) == len(xshape)
......
...@@ -37,7 +37,8 @@ from theano.scalar import bool as bool_t ...@@ -37,7 +37,8 @@ from theano.scalar import bool as bool_t
from theano.scalar import constant, get_scalar_type from theano.scalar import constant, get_scalar_type
from theano.scalar import int32 as int_t from theano.scalar import int32 as int_t
from theano.scalar import uint32 as uint32_t from theano.scalar import uint32 as uint32_t
from theano.tensor.basic import NotScalarConstantError, as_tensor_variable from theano.tensor.basic import as_tensor_variable
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.extra_ops import cpu_contiguous from theano.tensor.extra_ops import cpu_contiguous
from theano.tensor.nnet.abstract_conv import ( from theano.tensor.nnet.abstract_conv import (
AbstractConv2d, AbstractConv2d,
......
...@@ -24,7 +24,8 @@ from theano.gpuarray.type import GpuArrayType ...@@ -24,7 +24,8 @@ from theano.gpuarray.type import GpuArrayType
from theano.graph.basic import Apply from theano.graph.basic import Apply
from theano.graph.op import _NoPythonOp from theano.graph.op import _NoPythonOp
from theano.scalar import as_scalar from theano.scalar import as_scalar
from theano.tensor.basic import NotScalarConstantError, get_scalar_constant_value from theano.tensor.basic import get_scalar_constant_value
from theano.tensor.exceptions import NotScalarConstantError
class GPUAMultinomialFromUniform(GpuKernelBaseCOp, _NoPythonOp): class GPUAMultinomialFromUniform(GpuKernelBaseCOp, _NoPythonOp):
......
...@@ -16,8 +16,9 @@ from theano.gpuarray.type import GpuArrayType, get_context, move_to_gpu ...@@ -16,8 +16,9 @@ from theano.gpuarray.type import GpuArrayType, get_context, move_to_gpu
from theano.graph.basic import Constant from theano.graph.basic import Constant
from theano.graph.op import Op from theano.graph.op import Op
from theano.graph.opt import copy_stack_trace, inherit_stack_trace, local_optimizer from theano.graph.opt import copy_stack_trace, inherit_stack_trace, local_optimizer
from theano.tensor.basic import NotScalarConstantError, get_scalar_constant_value from theano.tensor.basic import get_scalar_constant_value
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.type import TensorType from theano.tensor.type import TensorType
......
...@@ -2114,7 +2114,7 @@ def _is_zero(x): ...@@ -2114,7 +2114,7 @@ def _is_zero(x):
try: try:
constant_value = theano.get_scalar_constant_value(x) constant_value = theano.get_scalar_constant_value(x)
no_constant_value = False no_constant_value = False
except theano.tensor.basic.NotScalarConstantError: except theano.tensor.exceptions.NotScalarConstantError:
pass pass
if no_constant_value: if no_constant_value:
......
...@@ -8,6 +8,7 @@ from theano.tensor import basic as tt ...@@ -8,6 +8,7 @@ from theano.tensor import basic as tt
from theano.tensor.basic import Dot from theano.tensor.basic import Dot
from theano.tensor.blas import Dot22 from theano.tensor.blas import Dot22
from theano.tensor.elemwise import DimShuffle, Prod from theano.tensor.elemwise import DimShuffle, Prod
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.nlinalg import ( from theano.tensor.nlinalg import (
MatrixInverse, MatrixInverse,
det, det,
...@@ -210,7 +211,7 @@ def is_positive(v): ...@@ -210,7 +211,7 @@ def is_positive(v):
if v.owner and v.owner.op == tt.pow: if v.owner and v.owner.op == tt.pow:
try: try:
exponent = tt.get_scalar_constant_value(v.owner.inputs[1]) exponent = tt.get_scalar_constant_value(v.owner.inputs[1])
except tt.NotScalarConstantError: except NotScalarConstantError:
return False return False
if 0 == exponent % 2: if 0 == exponent % 2:
return True return True
......
...@@ -27,7 +27,7 @@ from theano.scan import utils ...@@ -27,7 +27,7 @@ from theano.scan import utils
from theano.scan.op import Scan from theano.scan.op import Scan
from theano.scan.utils import safe_new, traverse from theano.scan.utils import safe_new, traverse
from theano.tensor import opt from theano.tensor import opt
from theano.tensor.basic import NotScalarConstantError from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.type import TensorType, integer_dtypes from theano.tensor.type import TensorType, integer_dtypes
from theano.updates import OrderedUpdates from theano.updates import OrderedUpdates
......
...@@ -5,6 +5,7 @@ __docformat__ = "restructuredtext en" ...@@ -5,6 +5,7 @@ __docformat__ = "restructuredtext en"
import warnings import warnings
import theano.tensor.exceptions
from theano.compile.ops import shape, specify_shape from theano.compile.ops import shape, specify_shape
from theano.gradient import consider_constant, grad, hessian, jacobian from theano.gradient import consider_constant, grad, hessian, jacobian
from theano.tensor import sharedvar # adds shared-variable constructors from theano.tensor import sharedvar # adds shared-variable constructors
......
...@@ -23,6 +23,7 @@ from theano.printing import min_informative_str, pprint ...@@ -23,6 +23,7 @@ from theano.printing import min_informative_str, pprint
from theano.scalar import int32 from theano.scalar import int32
from theano.tensor import elemwise from theano.tensor import elemwise
from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise, Sum, scalar_elemwise from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise, Sum, scalar_elemwise
from theano.tensor.exceptions import EmptyConstantError, NotScalarConstantError
from theano.tensor.type import ( from theano.tensor.type import (
TensorType, TensorType,
complex_dtypes, complex_dtypes,
...@@ -46,10 +47,6 @@ _logger = logging.getLogger("theano.tensor.basic") ...@@ -46,10 +47,6 @@ _logger = logging.getLogger("theano.tensor.basic")
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
class ShapeError(Exception):
"""Raised when the shape cannot be computed."""
def check_equal_numpy(x, y): def check_equal_numpy(x, y):
""" """
Return True iff x and y are equal. Return True iff x and y are equal.
...@@ -333,20 +330,6 @@ def _allclose(a, b, rtol=None, atol=None): ...@@ -333,20 +330,6 @@ def _allclose(a, b, rtol=None, atol=None):
return np.allclose(a, b, atol=atol_, rtol=rtol_) return np.allclose(a, b, atol=atol_, rtol=rtol_)
class NotScalarConstantError(Exception):
"""
Raised by get_scalar_constant_value if called on something that is
not a scalar constant.
"""
class EmptyConstantError(NotScalarConstantError):
"""
Raised by get_scalar_const_value if called on something that is a
zero dimensional constant.
"""
def numpy_scalar(data): def numpy_scalar(data):
"""Return a scalar stored in a numpy ndarray. """Return a scalar stored in a numpy ndarray.
...@@ -1867,7 +1850,7 @@ def round(a, mode=None): ...@@ -1867,7 +1850,7 @@ def round(a, mode=None):
elif mode == "half_to_even": elif mode == "half_to_even":
return round_half_to_even(a) return round_half_to_even(a)
else: else:
raise Exception(f"round mode {mode} is not implemented.") raise NotImplementedError(f"round mode {mode} is not implemented.")
@scalar_elemwise @scalar_elemwise
......
...@@ -162,6 +162,7 @@ from theano.scalar import bool as bool_t ...@@ -162,6 +162,7 @@ from theano.scalar import bool as bool_t
from theano.tensor import basic as tt from theano.tensor import basic as tt
from theano.tensor.blas_headers import blas_header_text, blas_header_version from theano.tensor.blas_headers import blas_header_text, blas_header_version
from theano.tensor.elemwise import DimShuffle, Elemwise from theano.tensor.elemwise import DimShuffle, Elemwise
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.opt import in2out, local_dimshuffle_lift from theano.tensor.opt import in2out, local_dimshuffle_lift
from theano.tensor.type import integer_dtypes, values_eq_approx_remove_inf_nan from theano.tensor.type import integer_dtypes, values_eq_approx_remove_inf_nan
from theano.utils import memoize from theano.utils import memoize
...@@ -1729,7 +1730,7 @@ def local_gemm_to_ger(fgraph, node): ...@@ -1729,7 +1730,7 @@ def local_gemm_to_ger(fgraph, node):
yv = y.dimshuffle(1) yv = y.dimshuffle(1)
try: try:
bval = tt.get_scalar_constant_value(b) bval = tt.get_scalar_constant_value(b)
except tt.NotScalarConstantError: except NotScalarConstantError:
# b isn't a constant, GEMM is doing useful pre-scaling # b isn't a constant, GEMM is doing useful pre-scaling
return return
......
class ShapeError(Exception):
"""Raised when the shape cannot be computed."""
class NotScalarConstantError(Exception):
"""
Raised by get_scalar_constant_value if called on something that is
not a scalar constant.
"""
class EmptyConstantError(NotScalarConstantError):
"""
Raised by get_scalar_const_value if called on something that is a
zero dimensional constant.
"""
class AdvancedIndexingError(TypeError):
"""
Raised when Subtensor is asked to perform advanced indexing.
"""
...@@ -19,6 +19,7 @@ from theano.scalar import int32 as int_t ...@@ -19,6 +19,7 @@ from theano.scalar import int32 as int_t
from theano.scalar import upcast from theano.scalar import upcast
from theano.tensor import basic as tt from theano.tensor import basic as tt
from theano.tensor import nlinalg from theano.tensor import nlinalg
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from theano.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from theano.tensor.type import ( from theano.tensor.type import (
TensorType, TensorType,
...@@ -690,7 +691,7 @@ class RepeatOp(Op): ...@@ -690,7 +691,7 @@ class RepeatOp(Op):
else: else:
try: try:
const_reps = tt.get_scalar_constant_value(repeats) const_reps = tt.get_scalar_constant_value(repeats)
except tt.NotScalarConstantError: except NotScalarConstantError:
const_reps = None const_reps = None
if const_reps == 1: if const_reps == 1:
broadcastable = x.broadcastable broadcastable = x.broadcastable
......
...@@ -483,7 +483,7 @@ class ConvOp(OpenMPOp): ...@@ -483,7 +483,7 @@ class ConvOp(OpenMPOp):
all_shape = self.has_all_shape(imshp, kshp, nkern, bsize) all_shape = self.has_all_shape(imshp, kshp, nkern, bsize)
if (unroll_batch or unroll_kern) and not all_shape: if (unroll_batch or unroll_kern) and not all_shape:
raise Exception( raise ValueError(
"In ConvOp, when using unroll_batch and" "In ConvOp, when using unroll_batch and"
" unroll_nkern, all shape are needed" " unroll_nkern, all shape are needed"
) )
...@@ -613,10 +613,10 @@ class ConvOp(OpenMPOp): ...@@ -613,10 +613,10 @@ class ConvOp(OpenMPOp):
self.out_mode = output_mode self.out_mode = output_mode
if self.out_mode not in ["valid", "full"]: if self.out_mode not in ["valid", "full"]:
raise Exception(f"Mode {self.out_mode} not implemented") raise NotImplementedError(f"Mode {self.out_mode} not implemented")
if any((shp is not None) and (shp <= 0) for shp in self.outshp): if any((shp is not None) and (shp <= 0) for shp in self.outshp):
raise Exception( raise ValueError(
"Bad size for the output shape. Verify that [post-" "Bad size for the output shape. Verify that [post-"
f"supersampling] input shape ({self.imshp_logical}) and kern" f"supersampling] input shape ({self.imshp_logical}) and kern"
f" shape({self.kshp_logical}) are ok. (Hint: kerns must fit inside" f" shape({self.kshp_logical}) are ok. (Hint: kerns must fit inside"
...@@ -1038,7 +1038,7 @@ class ConvOp(OpenMPOp): ...@@ -1038,7 +1038,7 @@ class ConvOp(OpenMPOp):
all_shape = self.has_all_shape(self.imshp, self.kshp, self.nkern, self.bsize) all_shape = self.has_all_shape(self.imshp, self.kshp, self.nkern, self.bsize)
if not all_shape and (self.dx != 1 or self.dy != 1): if not all_shape and (self.dx != 1 or self.dy != 1):
raise Exception( raise ValueError(
"ConvOp.grad when dx!=1 or dy!=1 we must have all " "ConvOp.grad when dx!=1 or dy!=1 we must have all "
"the optional shape information" "the optional shape information"
) )
...@@ -1486,7 +1486,9 @@ if(kerns_dim[1] != img2d_dim[1]){ ...@@ -1486,7 +1486,9 @@ if(kerns_dim[1] != img2d_dim[1]){
elif node.inputs[0].type.dtype == "float64": elif node.inputs[0].type.dtype == "float64":
d["type"] = "double" d["type"] = "double"
else: else:
raise Exception(f"Type {node.inputs[0].type.dtype} not implemented") raise NotImplementedError(
f"Type {node.inputs[0].type.dtype} not implemented"
)
d["gemm"] = "dgemm_" d["gemm"] = "dgemm_"
if not d["type"] == "double": if not d["type"] == "double":
d["gemm"] = "sgemm_" d["gemm"] = "sgemm_"
...@@ -2042,8 +2044,8 @@ def gen_conv_code_unroll_batch_kern(d, unroll_bsize=1, unroll_ksize=1): ...@@ -2042,8 +2044,8 @@ def gen_conv_code_unroll_batch_kern(d, unroll_bsize=1, unroll_ksize=1):
or "unroll_biter" in d or "unroll_biter" in d
or "unroll_kiter" in d or "unroll_kiter" in d
): ):
raise Exception( raise ValueError(
"We can't use this dictionary as we will overwrite some of its containt" "We can't use this dictionary as we will overwrite some of its content"
) )
d = d.copy() d = d.copy()
......
...@@ -34,6 +34,7 @@ from theano.tensor import basic as tt ...@@ -34,6 +34,7 @@ from theano.tensor import basic as tt
from theano.tensor import extra_ops, opt from theano.tensor import extra_ops, opt
from theano.tensor.basic import MaxAndArgmax, as_tensor_variable, log from theano.tensor.basic import MaxAndArgmax, as_tensor_variable, log
from theano.tensor.elemwise import Elemwise from theano.tensor.elemwise import Elemwise
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.nnet.blocksparse import sparse_block_dot from theano.tensor.nnet.blocksparse import sparse_block_dot
from theano.tensor.nnet.sigm import sigmoid, softplus from theano.tensor.nnet.sigm import sigmoid, softplus
from theano.tensor.opt import ( from theano.tensor.opt import (
...@@ -1758,7 +1759,7 @@ def _check_rows_is_arange_len_labels(fgraph, rows, labels): ...@@ -1758,7 +1759,7 @@ def _check_rows_is_arange_len_labels(fgraph, rows, labels):
def _is_const(z, val, approx=False): def _is_const(z, val, approx=False):
try: try:
maybe = opt.get_scalar_constant_value(z) maybe = opt.get_scalar_constant_value(z)
except tt.NotScalarConstantError: except NotScalarConstantError:
return False return False
if approx: if approx:
return np.allclose(maybe, val) return np.allclose(maybe, val)
......
...@@ -20,8 +20,8 @@ from theano.graph.utils import MethodNotDefined ...@@ -20,8 +20,8 @@ from theano.graph.utils import MethodNotDefined
from theano.printing import pprint from theano.printing import pprint
from theano.tensor import basic as tt from theano.tensor import basic as tt
from theano.tensor import opt from theano.tensor import opt
from theano.tensor.basic import NotScalarConstantError
from theano.tensor.elemwise import Elemwise from theano.tensor.elemwise import Elemwise
from theano.tensor.exceptions import NotScalarConstantError
from theano.tensor.type import TensorType, values_eq_approx_remove_inf from theano.tensor.type import TensorType, values_eq_approx_remove_inf
...@@ -463,7 +463,7 @@ def _is_1(expr): ...@@ -463,7 +463,7 @@ def _is_1(expr):
try: try:
v = opt.get_scalar_constant_value(expr) v = opt.get_scalar_constant_value(expr)
return np.allclose(v, 1) return np.allclose(v, 1)
except tt.NotScalarConstantError: except NotScalarConstantError:
return False return False
...@@ -1074,7 +1074,7 @@ def local_1msigmoid(fgraph, node): ...@@ -1074,7 +1074,7 @@ def local_1msigmoid(fgraph, node):
if sub_r.owner and sub_r.owner.op == sigmoid: if sub_r.owner and sub_r.owner.op == sigmoid:
try: try:
val_l = opt.get_scalar_constant_value(sub_l) val_l = opt.get_scalar_constant_value(sub_l)
except tt.NotScalarConstantError: except NotScalarConstantError:
return return
if np.allclose(np.sum(val_l), 1): if np.allclose(np.sum(val_l), 1):
out = sigmoid(-sub_r.owner.inputs[0]) out = sigmoid(-sub_r.owner.inputs[0])
......
...@@ -62,11 +62,9 @@ from theano.tensor.basic import ( ...@@ -62,11 +62,9 @@ from theano.tensor.basic import (
Dot, Dot,
Flatten, Flatten,
Join, Join,
NotScalarConstantError,
Rebroadcast, Rebroadcast,
Reshape, Reshape,
ScalarFromTensor, ScalarFromTensor,
ShapeError,
Split, Split,
TensorFromScalar, TensorFromScalar,
Tile, Tile,
...@@ -99,6 +97,7 @@ from theano.tensor.elemwise import ( ...@@ -99,6 +97,7 @@ from theano.tensor.elemwise import (
ProdWithoutZeros, ProdWithoutZeros,
Sum, Sum,
) )
from theano.tensor.exceptions import NotScalarConstantError, ShapeError
from theano.tensor.extra_ops import broadcast_shape from theano.tensor.extra_ops import broadcast_shape
from theano.tensor.sort import TopKOp from theano.tensor.sort import TopKOp
from theano.tensor.subtensor import ( from theano.tensor.subtensor import (
...@@ -1115,7 +1114,7 @@ class ShapeFeature(toolbox.Feature): ...@@ -1115,7 +1114,7 @@ class ShapeFeature(toolbox.Feature):
"Code called by infer_shape failed raising a " "Code called by infer_shape failed raising a "
"NotImplementedError. Raising NotImplementedError to " "NotImplementedError. Raising NotImplementedError to "
"indicate that a shape cannot be computed is no longer " "indicate that a shape cannot be computed is no longer "
"supported, and one should now use tensor.ShapeError " "supported, and one should now use ShapeError "
f"instead. The original exception message is: {e}" f"instead. The original exception message is: {e}"
).with_traceback(e.__traceback__) ).with_traceback(e.__traceback__)
except Exception as e: except Exception as e:
......
...@@ -26,6 +26,7 @@ from theano.tensor.basic import ( ...@@ -26,6 +26,7 @@ from theano.tensor.basic import (
get_scalar_constant_value, get_scalar_constant_value,
) )
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano.tensor.exceptions import AdvancedIndexingError, ShapeError
from theano.tensor.inc_code import inc_code from theano.tensor.inc_code import inc_code
from theano.tensor.type import ( from theano.tensor.type import (
bscalar, bscalar,
...@@ -62,13 +63,6 @@ invalid_tensor_types = ( ...@@ -62,13 +63,6 @@ invalid_tensor_types = (
) )
class AdvancedIndexingError(TypeError):
"""
Raised when Subtensor is asked to perform advanced indexing.
"""
def as_index_constant(a): def as_index_constant(a):
"""Convert Python literals to Theano constants--when possible--in Subtensor arguments. """Convert Python literals to Theano constants--when possible--in Subtensor arguments.
...@@ -2344,7 +2338,7 @@ class AdvancedSubtensor(Op): ...@@ -2344,7 +2338,7 @@ class AdvancedSubtensor(Op):
isinstance(idx, (np.bool_, bool)) isinstance(idx, (np.bool_, bool))
or getattr(idx, "dtype", None) == "bool" or getattr(idx, "dtype", None) == "bool"
): ):
raise theano.tensor.basic.ShapeError( raise ShapeError(
"Shape inference for boolean indices is not implemented" "Shape inference for boolean indices is not implemented"
) )
# The `ishapes` entries for `SliceType`s will be None, and # The `ishapes` entries for `SliceType`s will be None, and
......
...@@ -9,6 +9,7 @@ import theano ...@@ -9,6 +9,7 @@ import theano
from theano.configdefaults import config from theano.configdefaults import config
from theano.graph.basic import Constant, Variable from theano.graph.basic import Constant, Variable
from theano.scalar import ComplexError, IntegerDivisionError from theano.scalar import ComplexError, IntegerDivisionError
from theano.tensor.exceptions import AdvancedIndexingError
from theano.tensor.type import TensorType from theano.tensor.type import TensorType
from theano.tensor.utils import hash_from_ndarray from theano.tensor.utils import hash_from_ndarray
...@@ -533,7 +534,7 @@ class _tensor_py_operators: ...@@ -533,7 +534,7 @@ class _tensor_py_operators:
if arg is not np.newaxis: if arg is not np.newaxis:
try: try:
theano.tensor.subtensor.Subtensor.convert(arg) theano.tensor.subtensor.Subtensor.convert(arg)
except theano.tensor.subtensor.AdvancedIndexingError: except AdvancedIndexingError:
if advanced: if advanced:
axis = None axis = None
break break
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论