提交 9cbd604e authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Change the exception type when a context is not defined.

上级 d2de7531
...@@ -17,7 +17,8 @@ from theano.scan_module import scan_utils, scan_op, scan_opt ...@@ -17,7 +17,8 @@ from theano.scan_module import scan_utils, scan_op, scan_opt
from theano.tensor.nnet.conv import ConvOp from theano.tensor.nnet.conv import ConvOp
from theano.tests.breakpoint import PdbBreakpoint from theano.tests.breakpoint import PdbBreakpoint
from .type import GpuArrayType, GpuArrayConstant, get_context from .type import (GpuArrayType, GpuArrayConstant, get_context,
ContextNotDefined)
from .basic_ops import (as_gpuarray_variable, infer_context_name, from .basic_ops import (as_gpuarray_variable, infer_context_name,
host_from_gpu, GpuToGpu, host_from_gpu, GpuToGpu,
HostFromGpu, GpuFromHost, HostFromGpu, GpuFromHost,
...@@ -177,7 +178,7 @@ class InputToGpuOptimizer(Optimizer): ...@@ -177,7 +178,7 @@ class InputToGpuOptimizer(Optimizer):
except TypeError: except TypeError:
# This could fail if the inputs are not TensorTypes # This could fail if the inputs are not TensorTypes
pass pass
except ValueError: except ContextNotDefined:
# If there is no context tag and no default context # If there is no context tag and no default context
# then it stays on the CPU # then it stays on the CPU
if not hasattr(input.tag, 'context_name'): if not hasattr(input.tag, 'context_name'):
...@@ -255,7 +256,7 @@ def local_gpuaalloc2(node): ...@@ -255,7 +256,7 @@ def local_gpuaalloc2(node):
""" """
try: try:
get_context(None) get_context(None)
except ValueError: except ContextNotDefined:
# If there is no default context then we do not perform the move here. # If there is no default context then we do not perform the move here.
return return
if (isinstance(node.op, tensor.Alloc) and if (isinstance(node.op, tensor.Alloc) and
......
...@@ -17,6 +17,10 @@ except ImportError: ...@@ -17,6 +17,10 @@ except ImportError:
_context_reg = {} _context_reg = {}
class ContextNotDefined(ValueError):
pass
def reg_context(name, ctx): def reg_context(name, ctx):
""" """
Register a context by mapping it to a name. Register a context by mapping it to a name.
...@@ -56,7 +60,7 @@ def get_context(name): ...@@ -56,7 +60,7 @@ def get_context(name):
""" """
if name not in _context_reg: if name not in _context_reg:
raise ValueError("context name %s not defined" % (name,)) raise ContextNotDefined("context name %s not defined" % (name,))
return _context_reg[name] return _context_reg[name]
...@@ -72,7 +76,7 @@ def _name_for_ctx(ctx): ...@@ -72,7 +76,7 @@ def _name_for_ctx(ctx):
for k, v in _context_reg: for k, v in _context_reg:
if v == ctx: if v == ctx:
return k return k
raise ValueError('context is not registered') raise ContextNotDefined('context is not registered')
# This is a private method for use by the tests only # This is a private method for use by the tests only
...@@ -479,7 +483,7 @@ def gpuarray_shared_constructor(value, name=None, strict=False, ...@@ -479,7 +483,7 @@ def gpuarray_shared_constructor(value, name=None, strict=False,
try: try:
get_context(context_name) get_context(context_name)
except ValueError: except ContextNotDefined:
# Don't make this a hard error if we attempt to make a shared # Don't make this a hard error if we attempt to make a shared
# variable while there is no default context. # variable while there is no default context.
if context_name is None: if context_name is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论