提交 a1938f76 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Prevent theano.dot from catching all exceptions

上级 174cea8f
...@@ -172,28 +172,26 @@ else: ...@@ -172,28 +172,26 @@ else:
np.seterr(all=_all, divide=_divide, over=_over, under=_under, invalid=_invalid) np.seterr(all=_all, divide=_divide, over=_over, under=_under, invalid=_invalid)
del _all, _divide, _over, _under, _invalid del _all, _divide, _over, _under, _invalid
# This is defined here because it is designed to work across symbolic
# datatypes (Sparse and Tensor)
def dot(l, r): def dot(l, r):
"""Return a symbolic matrix/dot product between l and r """ """Return a symbolic dot product.
rval = NotImplemented
e0, e1 = None, None This is designed to work with both sparse and dense tensors types.
"""
if rval == NotImplemented and hasattr(l, "__dot__"): try:
try: res = l.__dot__(r)
rval = l.__dot__(r)
except Exception as e0: if res is NotImplemented:
rval = NotImplemented raise NotImplementedError()
if rval == NotImplemented and hasattr(r, "__rdot__"):
try: return res
rval = r.__rdot__(l) except (NotImplementedError, AttributeError, TypeError):
except Exception as e1: res = r.__rdot__(l)
rval = NotImplemented
if rval == NotImplemented: if res is NotImplemented:
raise NotImplementedError("Dot failed for the following reasons:", (e0, e1)) raise NotImplementedError()
return rval
return res
def get_scalar_constant_value(v): def get_scalar_constant_value(v):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论