提交 817d0cbd authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Do not return Gemm/Gemv candidates with the wrong dtype

上级 ad16bb64
......@@ -822,9 +822,14 @@ def _factor_canonicalized(lst):
# once i has touched a list element, it is permantent
lst = list(lst)
#print 'FACTOR', lst
#for (a,b) in lst:
#theano.printing.debugprint(a)
#theano.printing.debugprint(b)
#for t in lst:
# if not isinstance(t, (list, tuple)):
# t = (t,)
# for e in t:
# try:
# theano.printing.debugprint(e)
# except TypeError:
# print e, type(e)
i = 0
while i < len(lst)-1:
try:
......@@ -904,9 +909,8 @@ def _gemm_from_node2(node):
lst = _factor_canonicalized(lst)
rval = _gemm_from_factored_list(lst)
#print "RVAL", rval
if rval:
assert rval[0].type == node.outputs[0].type, (rval[0].type, node.outputs[0].type)
return rval
if rval and (rval[0].type == node.outputs[0].type):
return rval
class GemmOptimizer(Optimizer):
"""Graph optimizer for inserting Gemm operations"""
......
......@@ -479,6 +479,35 @@ def test_gemm_factor():
assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
assert [(2.0, X)] == _factor_canonicalized([(1.0, X),(1.0, X)])
def test_upcasting_scalar_nogemv():
# Test that the optimization does not crash when the scale has an incorrect
# dtype, and forces upcasting of the result
v = T.fvector('v')
w = T.fmatrix('w')
t = T.fvector('t')
alpha = T.dscalar('a')
rval = T.dot(w, v) * alpha + t
f = theano.function([w, v, t, alpha], rval)
t = f.maker.env.toposort()
assert numpy.sum([isinstance(n.op, Gemv) for n in t]) == 0
theano.printing.debugprint(f, print_type=True)
def test_upcasting_scalar_nogemm():
# Test that the optimization does not crash when the scale has an incorrect
# dtype, and forces upcasting of the result
v = T.fmatrix('v')
w = T.fmatrix('w')
t = T.fmatrix('t')
alpha = T.dscalar('a')
rval = T.dot(w, v) * alpha + t
f = theano.function([w, v, t, alpha], rval)
t = f.maker.env.toposort()
assert numpy.sum([isinstance(n.op, Gemm) for n in t]) == 0
theano.printing.debugprint(f, print_type=True)
def test_gemm_nested():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论