提交 83d55369 authored 作者: Frederic's avatar Frederic

Fixed an optimization crash during gemm optimization related to complex

上级 9d186f9f
......@@ -1120,8 +1120,17 @@ def _gemm_from_factored_list(lst):
return True
except Exception:
return False
lst = [(T.cast(sM[0],sM[1].type.dtype), sM[1])
for sM in lst if is_pair(sM)]
lst2 = []
# Remove the tuple that can't be casted correctly.
# This can happen when we try to cast a complex to a real
for sM in lst:
if is_pair(sM):
try:
lst2.append(T.cast(sM[0],sM[1].type.dtype), sM[1])
except TypeError:
pass
lst = lst2
# Try every pair in the sM_list, trying to turn it into a gemm operation
for i in xrange(len(lst) - 1):
......
......@@ -577,6 +577,17 @@ def test_upcasting_scalar_nogemm():
assert numpy.sum([isinstance(n.op, Gemm) for n in t]) == 0
theano.printing.debugprint(f, print_type=True)
v = T.fmatrix('v')
w = T.fmatrix('w')
t = T.fmatrix('t')
alpha = T.cscalar('a')
rval = T.dot(w, v) * alpha + t
f = theano.function([w, v, t, alpha], rval)
t = f.maker.env.toposort()
assert numpy.sum([isinstance(n.op, Gemm) for n in t]) == 0
theano.printing.debugprint(f, print_type=True)
def test_gemm_nested():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论