提交 29d37ffc authored 作者: Frederic's avatar Frederic

better fix to gh-1133. This do not change the interface to make_node in case of error.

上级 1074b68d
......@@ -69,6 +69,13 @@ class ShapeError(Exception):
pass
class AsTensorError(Exception):
"""Raised when as_tensor_variable isn't able to create a
TensorVariable.
"""
pass
def check_equal_numpy(x, y):
"""
Returns True iff x and y are equal (checks the dtype and
......@@ -151,8 +158,8 @@ def as_tensor_variable(x, name=None, ndim=None):
not possible.
:Exceptions:
- `ValueError`: raised if an `Apply` with no default output is fetched
- `TypeError`: raised if `x` cannot be converted to a TensorType Variable
- `ValueError`: raised if an `Apply` with more then one output is fetched
- `AsTensorError`: raised if `x` cannot be converted to a TensorType Variable
"""
if hasattr(x, '_as_TensorVariable'):
......@@ -171,7 +178,7 @@ def as_tensor_variable(x, name=None, ndim=None):
x = tensor_from_scalar(x)
if not isinstance(x.type, TensorType):
raise TypeError(
raise AsTensorError(
"Variable type field must be a TensorType.", x, x.type)
if ndim is None:
......@@ -194,7 +201,7 @@ def as_tensor_variable(x, name=None, ndim=None):
pass
if isinstance(x, bool):
raise TypeError(
raise AsTensorError(
"Cannot cast True or False as a tensor variable. Please use 1 or "
"0. This error might be caused by using the == operator on "
"Variables. v == w does not do what you think it does, "
......@@ -207,7 +214,7 @@ def as_tensor_variable(x, name=None, ndim=None):
str_x = str(x)
except Exception:
str_x = repr(x)
raise TypeError("Cannot convert %s to TensorType" % str_x, type(x))
raise AsTensorError("Cannot convert %s to TensorType" % str_x, type(x))
# this has a different name, because _as_tensor_variable is the
# function which ops use to upcast their arguments... this
......@@ -1586,7 +1593,7 @@ class _tensor_py_operators:
# Evidently, we need to catch NotImplementedError
# TypeError from as_tensor_variable are caught in Elemwise.make_node
# Oterwise TensorVariable * SparseVariable won't work!
except NotImplementedError:
except (NotImplementedError, AsTensorError):
# We must return NotImplemented and not an
# NotImplementedError or raise an NotImplementedError.
# That way python will give a good error message like this
......@@ -1599,7 +1606,7 @@ class _tensor_py_operators:
# and the return value in that case
try:
return sub(self, other)
except NotImplementedError:
except (NotImplementedError, AsTensorError):
return NotImplemented
def __mul__(self, other):
......@@ -1607,7 +1614,7 @@ class _tensor_py_operators:
# and the return value in that case
try:
return mul(self, other)
except NotImplementedError:
except (NotImplementedError, AsTensorError):
return NotImplemented
def __div__(self, other):
......@@ -1619,7 +1626,7 @@ class _tensor_py_operators:
# This is to raise the exception that occurs when trying to divide
# two integer arrays (currently forbidden).
raise
except NotImplementedError:
except (NotImplementedError, AsTensorError):
return NotImplemented
if PY3:
__truediv__ = __div__
......@@ -1629,7 +1636,7 @@ class _tensor_py_operators:
# adn the return value in that case
try:
return pow(self, other)
except NotImplementedError:
except (NotImplementedError, AsTensorError):
return NotImplemented
def __mod__(self, other):
......@@ -1641,7 +1648,7 @@ class _tensor_py_operators:
# This is to raise the exception that occurs when trying to compute
# x % y with either x or y a complex number.
raise
except NotImplementedError:
except (NotImplementedError, AsTensorError):
return NotImplemented
def __truediv__(self, other):
......
......@@ -530,13 +530,7 @@ class Elemwise(Op):
is left-completed to the greatest number of dimensions with 1s
using DimShuffle.
"""
try:
inputs = map(as_tensor_variable, inputs)
except TypeError:
# __{add,sub,mul,div,mod,pow}__
# need to return NotImplemented to make
# TensorVariable op SparseVariable work.
return NotImplemented
inputs = map(as_tensor_variable, inputs)
shadow = self.scalar_op.make_node(
*[Scalar(dtype=i.type.dtype)() for i in inputs])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论