提交 721d2b4e authored 作者: David Warde-Farley's avatar David Warde-Farley 提交者: Arnaud Bergeron

Use six.reraise for advanced exception raising.

上级 402a0a85
...@@ -12,7 +12,7 @@ from itertools import izip ...@@ -12,7 +12,7 @@ from itertools import izip
import numpy import numpy
from theano.compat import PY3 from theano.compat import PY3
from theano.compat.six import string_types from theano.compat.six import string_types, reraise
from theano.compat.six.moves import StringIO, xrange from theano.compat.six.moves import StringIO, xrange
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
...@@ -1537,8 +1537,7 @@ class _CThunk(object): ...@@ -1537,8 +1537,7 @@ class _CThunk(object):
' Was the error set in the c code?'), end=' ', file=sys.stderr) ' Was the error set in the c code?'), end=' ', file=sys.stderr)
print(self.error_storage, file=sys.stderr) print(self.error_storage, file=sys.stderr)
raise raise
reraise(exc_type, exc_value, exc_trace)
raise exc_type, exc_value, exc_trace
class OpWiseCLinker(link.LocalLinker): class OpWiseCLinker(link.LocalLinker):
......
...@@ -7,6 +7,8 @@ import traceback ...@@ -7,6 +7,8 @@ import traceback
import numpy import numpy
import theano import theano
from theano.compat import PY3
from theano.compat.six import reraise
from theano.compat.six.moves import StringIO from theano.compat.six.moves import StringIO
from theano.gof import utils from theano.gof import utils
from theano.gof import graph from theano.gof import graph
...@@ -100,7 +102,7 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None): ...@@ -100,7 +102,7 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
exc_type, exc_value, exc_trace = exc_info exc_type, exc_value, exc_trace = exc_info
if exc_type == KeyboardInterrupt: if exc_type == KeyboardInterrupt:
# print a simple traceback from KeyboardInterrupt # print a simple traceback from KeyboardInterrupt
raise exc_type, exc_value, exc_trace reraise(exc_type, exc_value, exc_trace)
try: try:
trace = node.outputs[0].tag.trace trace = node.outputs[0].tag.trace
except AttributeError: except AttributeError:
...@@ -290,7 +292,7 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None): ...@@ -290,7 +292,7 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
exc_value = exc_type(str(exc_value) + detailed_err_msg + exc_value = exc_type(str(exc_value) + detailed_err_msg +
'\n' + '\n'.join(hints)) '\n' + '\n'.join(hints))
raise exc_type, exc_value, exc_trace reraise(exc_type, exc_value, exc_trace)
class Linker(object): class Linker(object):
......
...@@ -942,10 +942,10 @@ def local_optimizer(tracks, inplace=False): ...@@ -942,10 +942,10 @@ def local_optimizer(tracks, inplace=False):
"""WRITEME""" """WRITEME"""
if tracks is not None: if tracks is not None:
if len(tracks) is 0: if len(tracks) is 0:
raise ValueError, ("Use None instead of an empty list to apply to all nodes.", f.__module__, f.__name__) raise ValueError("Use None instead of an empty list to apply to all nodes.", f.__module__, f.__name__)
for t in tracks: for t in tracks:
if not (isinstance(t, op.Op) or issubclass(t, op.PureOp)): if not (isinstance(t, op.Op) or issubclass(t, op.PureOp)):
raise ValueError, ("Tracks are op classes or instances", f.__module__, f.__name__) raise ValueError("Tracks are op classes or instances", f.__module__, f.__name__)
requirements = () requirements = ()
if inplace: if inplace:
dh_handler = dh.DestroyHandler dh_handler = dh.DestroyHandler
......
from __future__ import print_function from __future__ import print_function
from theano.compat.six import reraise
from theano import gof from theano import gof
import sys import sys
...@@ -81,7 +82,7 @@ class DebugLinker(gof.WrapLinker): ...@@ -81,7 +82,7 @@ class DebugLinker(gof.WrapLinker):
exc.node = node exc.node = node
exc.thunk = thunk exc.thunk = thunk
exc.linker = linker exc.linker = linker
raise DebugException, exc, exc_trace reraise(DebugException, exc, exc_trace)
def compare_variables(self, i, node, *thunks): def compare_variables(self, i, node, *thunks):
thunk0 = thunks[0] thunk0 = thunks[0]
...@@ -146,7 +147,7 @@ class DebugLinker(gof.WrapLinker): ...@@ -146,7 +147,7 @@ class DebugLinker(gof.WrapLinker):
exc.step = i exc.step = i
exc.node = node exc.node = node
exc.thunks = thunks exc.thunks = thunks
raise DebugException, exc, exc_trace reraise(DebugException, exc, exc_trace)
def print_info(i, node, *thunks): def print_info(i, node, *thunks):
......
...@@ -73,7 +73,7 @@ class Kernel(object): ...@@ -73,7 +73,7 @@ class Kernel(object):
elif isinstance(t, Variable): elif isinstance(t, Variable):
return t.type.dtype return t.type.dtype
else: else:
raise TypeError, "can't get a dtype from %s" % (type(t),) raise TypeError("can't get a dtype from %s" % (type(t),))
dtypes = [get_dtype(t) for t in types] dtypes = [get_dtype(t) for t in types]
flags = dict(cluda=True) flags = dict(cluda=True)
if any(d == numpy.float64 for d in dtypes): if any(d == numpy.float64 for d in dtypes):
...@@ -116,7 +116,7 @@ class GpuKernelBase(object): ...@@ -116,7 +116,7 @@ class GpuKernelBase(object):
iterable of Kernel objects that describe the kernels this op iterable of Kernel objects that describe the kernels this op
will need. will need.
""" """
raise MethodNotDefined, 'gpu_kernels' raise MethodNotDefined('gpu_kernels')
def c_headers(self): def c_headers(self):
try: try:
......
...@@ -40,9 +40,9 @@ class NVCC_compiler(NVCC_base): ...@@ -40,9 +40,9 @@ class NVCC_compiler(NVCC_base):
if not any(['-arch=sm_' in f for f in flags]): if not any(['-arch=sm_' in f for f in flags]):
dev = theano.sandbox.gpuarray.init_dev.device dev = theano.sandbox.gpuarray.init_dev.device
if dev is None: if dev is None:
raise Exception, "Trying to compile GPU code without a context" raise Exception("Trying to compile GPU code without a context")
if dev.startswith("opencl"): if dev.startswith("opencl"):
raise Exception, "Trying to call nvcc with an OpenCL context" raise Exception("Trying to call nvcc with an OpenCL context")
assert dev.startswith('cuda') assert dev.startswith('cuda')
if dev == 'cuda': if dev == 'cuda':
n = theano.sandbox.cuda.use.device_number n = theano.sandbox.cuda.use.device_number
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论