提交 a1938f76 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Prevent theano.dot from catching all exceptions

上级 174cea8f
......@@ -172,28 +172,26 @@ else:
np.seterr(all=_all, divide=_divide, over=_over, under=_under, invalid=_invalid)
del _all, _divide, _over, _under, _invalid
# This is defined here because it is designed to work across symbolic
# datatypes (Sparse and Tensor)
def dot(l, r):
"""Return a symbolic matrix/dot product between l and r """
rval = NotImplemented
e0, e1 = None, None
if rval == NotImplemented and hasattr(l, "__dot__"):
try:
rval = l.__dot__(r)
except Exception as e0:
rval = NotImplemented
if rval == NotImplemented and hasattr(r, "__rdot__"):
try:
rval = r.__rdot__(l)
except Exception as e1:
rval = NotImplemented
if rval == NotImplemented:
raise NotImplementedError("Dot failed for the following reasons:", (e0, e1))
return rval
"""Return a symbolic dot product.
This is designed to work with both sparse and dense tensors types.
"""
try:
res = l.__dot__(r)
if res is NotImplemented:
raise NotImplementedError()
return res
except (NotImplementedError, AttributeError, TypeError):
res = r.__rdot__(l)
if res is NotImplemented:
raise NotImplementedError()
return res
def get_scalar_constant_value(v):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论