提交 9733a2ab authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix exception handling in as_tensor_variable

This change prevents downstream code (e.g. test value computation) from throwing exceptions (e.g. `ValueError`s) that break the logic in `theano.tensor.basic.as_tensor_variable`.
上级 ba126742
...@@ -2756,18 +2756,18 @@ class ApplyDefaultTestOp(theano.Op): ...@@ -2756,18 +2756,18 @@ class ApplyDefaultTestOp(theano.Op):
def test_constant(): def test_constant():
int8_type = tensor.TensorType(dtype="int8", broadcastable=(False,)) int8_vector_type = tensor.TensorType(dtype="int8", broadcastable=(False,))
# Make sure we return a `TensorConstant` unchanged # Make sure we return a `TensorConstant` unchanged
x = tensor.TensorConstant(int8_type, [1, 2]) x = tensor.TensorConstant(int8_vector_type, [1, 2])
y = constant(x) y = constant(x)
assert y is x assert y is x
# Make sure we can add and remove broadcastable dimensions # Make sure we can add and remove broadcastable dimensions
int8_type = tensor.TensorType(dtype="int8", broadcastable=()) int8_scalar_type = tensor.TensorType(dtype="int8", broadcastable=())
x_data = np.array(2, dtype="int8") x_data = np.array(2, dtype="int8")
x = tensor.TensorConstant(int8_type, x_data) x = tensor.TensorConstant(int8_scalar_type, x_data)
y = constant(x, ndim=1) y = constant(x, ndim=1)
assert y.ndim == 1 assert y.ndim == 1
assert np.array_equal(y.data, np.expand_dims(x_data, 0)) assert np.array_equal(y.data, np.expand_dims(x_data, 0))
...@@ -2794,21 +2794,27 @@ class TestAsTensorVariable: ...@@ -2794,21 +2794,27 @@ class TestAsTensorVariable:
y = as_tensor_variable(scal.int8()) y = as_tensor_variable(scal.int8())
assert isinstance(y.owner.op, TensorFromScalar) assert isinstance(y.owner.op, TensorFromScalar)
def test_one_output(self): def test_multi_outputs(self):
good_apply_var = ApplyDefaultTestOp(0).make_node(self.x) good_apply_var = ApplyDefaultTestOp(0).make_node(self.x)
as_tensor_variable(good_apply_var) as_tensor_variable(good_apply_var)
def test_below_zero_output(self):
bad_apply_var = ApplyDefaultTestOp(-1).make_node(self.x) bad_apply_var = ApplyDefaultTestOp(-1).make_node(self.x)
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
as_tensor_variable(bad_apply_var) _ = as_tensor_variable(bad_apply_var)
def test_above_output_len(self):
bad_apply_var = ApplyDefaultTestOp(2).make_node(self.x) bad_apply_var = ApplyDefaultTestOp(2).make_node(self.x)
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
as_tensor_variable(bad_apply_var) _ = as_tensor_variable(bad_apply_var)
def test_list(self): def test_list(self):
# Make sure our exception handling during `Sequence` processing doesn't
# mask exceptions caused by unrelated logic (e.g. computing test
# values)
with change_flags(compute_test_value="raise"), pytest.raises(ValueError):
a = tensor.lscalar("a")
y = (a, a, 1)
_ = as_tensor_variable(y)
bad_apply_var = ApplyDefaultTestOp([0, 1]).make_node(self.x) bad_apply_var = ApplyDefaultTestOp([0, 1]).make_node(self.x)
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
as_tensor_variable(bad_apply_var) as_tensor_variable(bad_apply_var)
......
...@@ -10,27 +10,20 @@ If you want to use a scalar variable in a Theano graph, ...@@ -10,27 +10,20 @@ If you want to use a scalar variable in a Theano graph,
you probably want to use theano.tensor.[c,z,f,d,b,w,i,l,]scalar! you probably want to use theano.tensor.[c,z,f,d,b,w,i,l,]scalar!
""" """
from itertools import chain
import math import math
import warnings import warnings
from copy import copy from copy import copy
from functools import partial
from itertools import chain
from textwrap import dedent from textwrap import dedent
import numpy as np import numpy as np
import six import six
import theano import theano
from theano import config, gof, printing
from theano.compat import Callable from theano.compat import Callable
from theano import gof, printing from theano.gof import Apply, Constant, FunctionGraph, Op, Type, Variable, utils
from theano.gof import Op, utils, Variable, Constant, Type, Apply, FunctionGraph from theano.gradient import DisconnectedType, grad_undefined
from functools import partial
from theano import config
from theano.gradient import DisconnectedType
from theano.gradient import grad_undefined
from theano.printing import pprint from theano.printing import pprint
builtin_bool = bool builtin_bool = bool
...@@ -267,9 +260,14 @@ class autocast_float_as(object): ...@@ -267,9 +260,14 @@ class autocast_float_as(object):
def convert(x, dtype=None): def convert(x, dtype=None):
""" """Convert the input to a properly typed NumPy value according to the current casting policy.
Convert the input to a properly typed numpy value according to the
current casting policy. Work with scalars and tensors. Parameters
----------
x : Number, numpy.ndarray, or Sequence[Number]
The value(s) to be converted
dtype : str or numpy.dtype (optional)
The dtype to use for the conversion of `x`.
""" """
if dtype is not None: if dtype is not None:
......
...@@ -9,7 +9,6 @@ from theano.tensor.basic import * ...@@ -9,7 +9,6 @@ from theano.tensor.basic import *
from theano.tensor.subtensor import * from theano.tensor.subtensor import *
from theano.tensor.type_other import * from theano.tensor.type_other import *
from theano.tensor.var import ( from theano.tensor.var import (
AsTensorError,
_tensor_py_operators, _tensor_py_operators,
TensorVariable, TensorVariable,
TensorConstantSignature, TensorConstantSignature,
......
...@@ -22,7 +22,6 @@ from theano.gof.type import Generic ...@@ -22,7 +22,6 @@ from theano.gof.type import Generic
from theano.scalar import int32 from theano.scalar import int32
from theano.tensor import elemwise from theano.tensor import elemwise
from theano.tensor.var import ( from theano.tensor.var import (
AsTensorError,
TensorVariable, TensorVariable,
TensorConstant, TensorConstant,
_tensor_py_operators, _tensor_py_operators,
...@@ -132,10 +131,7 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -132,10 +131,7 @@ def as_tensor_variable(x, name=None, ndim=None):
Raises Raises
------ ------
ValueError TypeError
If an `Apply` with more than one output is fetched or
if `x` cannot be made into a Variable with `ndim` dimensions.
AsTensorError
If `x` cannot be converted to a TensorType Variable. If `x` cannot be converted to a TensorType Variable.
""" """
...@@ -152,7 +148,7 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -152,7 +148,7 @@ def as_tensor_variable(x, name=None, ndim=None):
if isinstance(x, gof.Apply): if isinstance(x, gof.Apply):
# use Apply's default output mechanism # use Apply's default output mechanism
if (x.op.default_output is None) and (len(x.outputs) != 1): if (x.op.default_output is None) and (len(x.outputs) != 1):
raise ValueError( raise TypeError(
"Multi-output Op encountered. " "Multi-output Op encountered. "
"Retry using only one of the outputs directly." "Retry using only one of the outputs directly."
) )
...@@ -168,7 +164,7 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -168,7 +164,7 @@ def as_tensor_variable(x, name=None, ndim=None):
x = tensor_from_scalar(x) x = tensor_from_scalar(x)
if not isinstance(x.type, TensorType): if not isinstance(x.type, TensorType):
raise AsTensorError( raise TypeError(
"Tensor type field must be a TensorType; found {}.".format(type(x.type)) "Tensor type field must be a TensorType; found {}.".format(type(x.type))
) )
...@@ -207,13 +203,10 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -207,13 +203,10 @@ def as_tensor_variable(x, name=None, ndim=None):
try: try:
x = [extract_constants(i) for i in x] x = [extract_constants(i) for i in x]
except TypeError: except TypeError:
try: return stack(x)
return stack(x)
except (TypeError, ValueError):
pass
elif isinstance(x, bool): elif isinstance(x, bool):
raise AsTensorError( raise TypeError(
"Cannot cast True or False as a tensor variable. Please use " "Cannot cast True or False as a tensor variable. Please use "
"np.array(True) or np.array(False) if you need these constants. " "np.array(True) or np.array(False) if you need these constants. "
"This error might be caused by using the == operator on " "This error might be caused by using the == operator on "
...@@ -221,12 +214,7 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -221,12 +214,7 @@ def as_tensor_variable(x, name=None, ndim=None):
"use theano.tensor.eq(v, w) instead." "use theano.tensor.eq(v, w) instead."
) )
try: return constant(x, name=name, ndim=ndim)
return constant(x, name=name, ndim=ndim)
except TypeError:
raise AsTensorError(
"Cannot convert {} of type {} to TensorType".format(x, type(x))
)
# this has a different name, because _as_tensor_variable is the # this has a different name, because _as_tensor_variable is the
......
...@@ -16,15 +16,6 @@ from theano.tensor.type import TensorType ...@@ -16,15 +16,6 @@ from theano.tensor.type import TensorType
from theano import config from theano import config
class AsTensorError(TypeError):
"""
Raised when as_tensor_variable isn't able to create a TensorVariable.
"""
pass
class _tensor_py_operators(object): class _tensor_py_operators(object):
def __abs__(self): def __abs__(self):
return theano.tensor.basic.abs_(self) return theano.tensor.basic.abs_(self)
...@@ -117,7 +108,7 @@ class _tensor_py_operators(object): ...@@ -117,7 +108,7 @@ class _tensor_py_operators(object):
# Evidently, we need to catch NotImplementedError # Evidently, we need to catch NotImplementedError
# TypeError from as_tensor_variable are caught in Elemwise.make_node # TypeError from as_tensor_variable are caught in Elemwise.make_node
# Oterwise TensorVariable * SparseVariable won't work! # Oterwise TensorVariable * SparseVariable won't work!
except (NotImplementedError, AsTensorError): except (NotImplementedError, TypeError):
# We must return NotImplemented and not an # We must return NotImplemented and not an
# NotImplementedError or raise an NotImplementedError. # NotImplementedError or raise an NotImplementedError.
# That way python will give a good error message like this # That way python will give a good error message like this
...@@ -130,7 +121,7 @@ class _tensor_py_operators(object): ...@@ -130,7 +121,7 @@ class _tensor_py_operators(object):
# and the return value in that case # and the return value in that case
try: try:
return theano.tensor.basic.sub(self, other) return theano.tensor.basic.sub(self, other)
except (NotImplementedError, AsTensorError): except (NotImplementedError, TypeError):
return NotImplemented return NotImplemented
def __mul__(self, other): def __mul__(self, other):
...@@ -138,7 +129,7 @@ class _tensor_py_operators(object): ...@@ -138,7 +129,7 @@ class _tensor_py_operators(object):
# and the return value in that case # and the return value in that case
try: try:
return theano.tensor.mul(self, other) return theano.tensor.mul(self, other)
except (NotImplementedError, AsTensorError): except (NotImplementedError, TypeError):
return NotImplemented return NotImplemented
def __div__(self, other): def __div__(self, other):
...@@ -150,7 +141,7 @@ class _tensor_py_operators(object): ...@@ -150,7 +141,7 @@ class _tensor_py_operators(object):
# This is to raise the exception that occurs when trying to divide # This is to raise the exception that occurs when trying to divide
# two integer arrays (currently forbidden). # two integer arrays (currently forbidden).
raise raise
except (NotImplementedError, AsTensorError): except (NotImplementedError, TypeError):
return NotImplemented return NotImplemented
__truediv__ = __div__ __truediv__ = __div__
...@@ -160,7 +151,7 @@ class _tensor_py_operators(object): ...@@ -160,7 +151,7 @@ class _tensor_py_operators(object):
# and the return value in that case # and the return value in that case
try: try:
return theano.tensor.basic.pow(self, other) return theano.tensor.basic.pow(self, other)
except (NotImplementedError, AsTensorError): except (NotImplementedError, TypeError):
return NotImplemented return NotImplemented
def __mod__(self, other): def __mod__(self, other):
...@@ -172,7 +163,7 @@ class _tensor_py_operators(object): ...@@ -172,7 +163,7 @@ class _tensor_py_operators(object):
# This is to raise the exception that occurs when trying to compute # This is to raise the exception that occurs when trying to compute
# x % y with either x or y a complex number. # x % y with either x or y a complex number.
raise raise
except (NotImplementedError, AsTensorError): except (NotImplementedError, TypeError):
return NotImplemented return NotImplemented
def __divmod__(self, other): def __divmod__(self, other):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论