提交 9733a2ab authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix exception handling in as_tensor_variable

This change prevents downstream code (e.g. test value computation) from throwing exceptions (e.g. `ValueError`s) that break the logic in `theano.tensor.basic.as_tensor_variable`.
上级 ba126742
......@@ -2756,18 +2756,18 @@ class ApplyDefaultTestOp(theano.Op):
def test_constant():
int8_type = tensor.TensorType(dtype="int8", broadcastable=(False,))
int8_vector_type = tensor.TensorType(dtype="int8", broadcastable=(False,))
# Make sure we return a `TensorConstant` unchanged
x = tensor.TensorConstant(int8_type, [1, 2])
x = tensor.TensorConstant(int8_vector_type, [1, 2])
y = constant(x)
assert y is x
# Make sure we can add and remove broadcastable dimensions
int8_type = tensor.TensorType(dtype="int8", broadcastable=())
int8_scalar_type = tensor.TensorType(dtype="int8", broadcastable=())
x_data = np.array(2, dtype="int8")
x = tensor.TensorConstant(int8_type, x_data)
x = tensor.TensorConstant(int8_scalar_type, x_data)
y = constant(x, ndim=1)
assert y.ndim == 1
assert np.array_equal(y.data, np.expand_dims(x_data, 0))
......@@ -2794,21 +2794,27 @@ class TestAsTensorVariable:
y = as_tensor_variable(scal.int8())
assert isinstance(y.owner.op, TensorFromScalar)
def test_one_output(self):
def test_multi_outputs(self):
good_apply_var = ApplyDefaultTestOp(0).make_node(self.x)
as_tensor_variable(good_apply_var)
def test_below_zero_output(self):
bad_apply_var = ApplyDefaultTestOp(-1).make_node(self.x)
with pytest.raises(AttributeError):
as_tensor_variable(bad_apply_var)
_ = as_tensor_variable(bad_apply_var)
def test_above_output_len(self):
bad_apply_var = ApplyDefaultTestOp(2).make_node(self.x)
with pytest.raises(AttributeError):
as_tensor_variable(bad_apply_var)
_ = as_tensor_variable(bad_apply_var)
def test_list(self):
# Make sure our exception handling during `Sequence` processing doesn't
# mask exceptions caused by unrelated logic (e.g. computing test
# values)
with change_flags(compute_test_value="raise"), pytest.raises(ValueError):
a = tensor.lscalar("a")
y = (a, a, 1)
_ = as_tensor_variable(y)
bad_apply_var = ApplyDefaultTestOp([0, 1]).make_node(self.x)
with pytest.raises(AttributeError):
as_tensor_variable(bad_apply_var)
......
......@@ -10,27 +10,20 @@ If you want to use a scalar variable in a Theano graph,
you probably want to use theano.tensor.[c,z,f,d,b,w,i,l,]scalar!
"""
from itertools import chain
import math
import warnings
from copy import copy
from functools import partial
from itertools import chain
from textwrap import dedent
import numpy as np
import six
import theano
from theano import config, gof, printing
from theano.compat import Callable
from theano import gof, printing
from theano.gof import Op, utils, Variable, Constant, Type, Apply, FunctionGraph
from functools import partial
from theano import config
from theano.gradient import DisconnectedType
from theano.gradient import grad_undefined
from theano.gof import Apply, Constant, FunctionGraph, Op, Type, Variable, utils
from theano.gradient import DisconnectedType, grad_undefined
from theano.printing import pprint
builtin_bool = bool
......@@ -267,9 +260,14 @@ class autocast_float_as(object):
def convert(x, dtype=None):
"""
Convert the input to a properly typed numpy value according to the
current casting policy. Work with scalars and tensors.
"""Convert the input to a properly typed NumPy value according to the current casting policy.
Parameters
----------
x : Number, numpy.ndarray, or Sequence[Number]
The value(s) to be converted
dtype : str or numpy.dtype (optional)
The dtype to use for the conversion of `x`.
"""
if dtype is not None:
......
......@@ -9,7 +9,6 @@ from theano.tensor.basic import *
from theano.tensor.subtensor import *
from theano.tensor.type_other import *
from theano.tensor.var import (
AsTensorError,
_tensor_py_operators,
TensorVariable,
TensorConstantSignature,
......
......@@ -22,7 +22,6 @@ from theano.gof.type import Generic
from theano.scalar import int32
from theano.tensor import elemwise
from theano.tensor.var import (
AsTensorError,
TensorVariable,
TensorConstant,
_tensor_py_operators,
......@@ -132,10 +131,7 @@ def as_tensor_variable(x, name=None, ndim=None):
Raises
------
ValueError
If an `Apply` with more than one output is fetched or
if `x` cannot be made into a Variable with `ndim` dimensions.
AsTensorError
TypeError
If `x` cannot be converted to a TensorType Variable.
"""
......@@ -152,7 +148,7 @@ def as_tensor_variable(x, name=None, ndim=None):
if isinstance(x, gof.Apply):
# use Apply's default output mechanism
if (x.op.default_output is None) and (len(x.outputs) != 1):
raise ValueError(
raise TypeError(
"Multi-output Op encountered. "
"Retry using only one of the outputs directly."
)
......@@ -168,7 +164,7 @@ def as_tensor_variable(x, name=None, ndim=None):
x = tensor_from_scalar(x)
if not isinstance(x.type, TensorType):
raise AsTensorError(
raise TypeError(
"Tensor type field must be a TensorType; found {}.".format(type(x.type))
)
......@@ -207,13 +203,10 @@ def as_tensor_variable(x, name=None, ndim=None):
try:
x = [extract_constants(i) for i in x]
except TypeError:
try:
return stack(x)
except (TypeError, ValueError):
pass
return stack(x)
elif isinstance(x, bool):
raise AsTensorError(
raise TypeError(
"Cannot cast True or False as a tensor variable. Please use "
"np.array(True) or np.array(False) if you need these constants. "
"This error might be caused by using the == operator on "
......@@ -221,12 +214,7 @@ def as_tensor_variable(x, name=None, ndim=None):
"use theano.tensor.eq(v, w) instead."
)
try:
return constant(x, name=name, ndim=ndim)
except TypeError:
raise AsTensorError(
"Cannot convert {} of type {} to TensorType".format(x, type(x))
)
return constant(x, name=name, ndim=ndim)
# this has a different name, because _as_tensor_variable is the
......
......@@ -16,15 +16,6 @@ from theano.tensor.type import TensorType
from theano import config
class AsTensorError(TypeError):
"""
Raised when as_tensor_variable isn't able to create a TensorVariable.
"""
pass
class _tensor_py_operators(object):
def __abs__(self):
return theano.tensor.basic.abs_(self)
......@@ -117,7 +108,7 @@ class _tensor_py_operators(object):
# Evidently, we need to catch NotImplementedError
# TypeError from as_tensor_variable are caught in Elemwise.make_node
# Oterwise TensorVariable * SparseVariable won't work!
except (NotImplementedError, AsTensorError):
except (NotImplementedError, TypeError):
# We must return NotImplemented and not an
# NotImplementedError or raise an NotImplementedError.
# That way python will give a good error message like this
......@@ -130,7 +121,7 @@ class _tensor_py_operators(object):
# and the return value in that case
try:
return theano.tensor.basic.sub(self, other)
except (NotImplementedError, AsTensorError):
except (NotImplementedError, TypeError):
return NotImplemented
def __mul__(self, other):
......@@ -138,7 +129,7 @@ class _tensor_py_operators(object):
# and the return value in that case
try:
return theano.tensor.mul(self, other)
except (NotImplementedError, AsTensorError):
except (NotImplementedError, TypeError):
return NotImplemented
def __div__(self, other):
......@@ -150,7 +141,7 @@ class _tensor_py_operators(object):
# This is to raise the exception that occurs when trying to divide
# two integer arrays (currently forbidden).
raise
except (NotImplementedError, AsTensorError):
except (NotImplementedError, TypeError):
return NotImplemented
__truediv__ = __div__
......@@ -160,7 +151,7 @@ class _tensor_py_operators(object):
# and the return value in that case
try:
return theano.tensor.basic.pow(self, other)
except (NotImplementedError, AsTensorError):
except (NotImplementedError, TypeError):
return NotImplemented
def __mod__(self, other):
......@@ -172,7 +163,7 @@ class _tensor_py_operators(object):
# This is to raise the exception that occurs when trying to compute
# x % y with either x or y a complex number.
raise
except (NotImplementedError, AsTensorError):
except (NotImplementedError, TypeError):
return NotImplemented
def __divmod__(self, other):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论