提交 0a62b7bc，作者：James Bergstra

fixed GemmOptimizer bug

上级 ec7877c4
......@@ -403,17 +403,29 @@ class GemmLocalOptimizer(LocalOptimizer):
rval = beta_L_plus_alpha_M(sL, mL, -sR, mR)
return rval
if node.op == T.add:
# arguments of the form scalar * matrix
sM_list = []
# arguments that can be interpreted as scalar * matrix
sM_orig = []
# arguments not of the form scalar * matrix (i.e., vectors, scalars)
other_inputs = []
for input in node.inputs:
tmp = _as_isolated_scalar_times_matrix(input)
if tmp:
sM_list.append(tmp)
sM_orig.append(input)
elif _is_real_matrix(input):
sM_list.append((1.0, input))
sM_orig.append(input)
else:
other_inputs.append(input)
assert len(sM_list) == len(sM_orig)
assert len(sM_list) + len(other_inputs) == len(node.inputs)
if len(sM_list) == 2:
(sL, mL), (sR, mR) = sM_list
gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR)
......@@ -425,16 +437,34 @@ class GemmLocalOptimizer(LocalOptimizer):
else:
return gemm_of_sM_list
else:
# Try every pair in the sM_list, trying to turn it into a gemm operation
for i in xrange(len(sM_list) - 1):
for j in xrange(i+1, len(sM_list)):
assert i != j
sL, mL = sM_list[i]
sR, mR = sM_list[j]
gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR)
if gemm_of_sM_list:
assert len(gemm_of_sM_list) == 1
inputs_without_ij = \
[input for k, input in enumerate(node.inputs) if k not in (i,j)]
return [T.add( *(inputs_without_ij + gemm_of_sM_list + other_inputs))]
[input for k, input in enumerate(sM_orig) if k not in (i,j)]
new_add_inputs = (inputs_without_ij + gemm_of_sM_list + other_inputs)
if False: #SUPER DEBUG MODE :(
if len(new_add_inputs) + 1 != len(node.inputs):
print 'inputs', node.inputs
print 'sM, other', sM_list, other_inputs
print 'i,j', i, j
print 'gemm', gemm_of_sM_list
print 'without ij', inputs_without_ij
print 'new inputs', new_add_inputs
sys.exit(1)
# this should be True because we've combined a pair of arguments
# into a single GEMM
assert len(new_add_inputs) + 1 == len(node.inputs)
return [T.add(*new_add_inputs)]
return False
@staticmethod
......@@ -443,7 +473,7 @@ class GemmLocalOptimizer(LocalOptimizer):
if not isinstance(exc, InconsistencyError):
traceback.print_exc()
else:
#print 'GEMM caused cycle, forget it.'
#print 'GEMM caused cycle, it happens.'
pass
@staticmethod
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论