提交 a1938f76 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Prevent theano.dot from catching all exceptions

上级 174cea8f
...@@ -172,28 +172,26 @@ else: ...@@ -172,28 +172,26 @@ else:
np.seterr(all=_all, divide=_divide, over=_over, under=_under, invalid=_invalid) np.seterr(all=_all, divide=_divide, over=_over, under=_under, invalid=_invalid)
del _all, _divide, _over, _under, _invalid del _all, _divide, _over, _under, _invalid
# This is defined here because it is designed to work across symbolic
# datatypes (Sparse and Tensor)
def dot(l, r): def dot(l, r):
"""Return a symbolic matrix/dot product between l and r """ """Return a symbolic dot product.
rval = NotImplemented
e0, e1 = None, None
if rval == NotImplemented and hasattr(l, "__dot__"): This is designed to work with both sparse and dense tensors types.
try: """
rval = l.__dot__(r)
except Exception as e0:
rval = NotImplemented
if rval == NotImplemented and hasattr(r, "__rdot__"):
try: try:
rval = r.__rdot__(l) res = l.__dot__(r)
except Exception as e1:
rval = NotImplemented if res is NotImplemented:
if rval == NotImplemented: raise NotImplementedError()
raise NotImplementedError("Dot failed for the following reasons:", (e0, e1))
return rval return res
except (NotImplementedError, AttributeError, TypeError):
res = r.__rdot__(l)
if res is NotImplemented:
raise NotImplementedError()
return res
def get_scalar_constant_value(v): def get_scalar_constant_value(v):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论