提交 0a62b7bc authored 作者: James Bergstra's avatar James Bergstra

fixed GemmOptimizer bug

上级 ec7877c4
...@@ -403,17 +403,29 @@ class GemmLocalOptimizer(LocalOptimizer): ...@@ -403,17 +403,29 @@ class GemmLocalOptimizer(LocalOptimizer):
rval = beta_L_plus_alpha_M(sL, mL, -sR, mR) rval = beta_L_plus_alpha_M(sL, mL, -sR, mR)
return rval return rval
if node.op == T.add: if node.op == T.add:
# arguments of the form scalar * matrix
sM_list = [] sM_list = []
# arguments that can be interpreted as scalar * matrix
sM_orig = []
# arguments not of the form scalar * matrix (i.e., vectors, scalars)
other_inputs = [] other_inputs = []
for input in node.inputs: for input in node.inputs:
tmp = _as_isolated_scalar_times_matrix(input) tmp = _as_isolated_scalar_times_matrix(input)
if tmp: if tmp:
sM_list.append(tmp) sM_list.append(tmp)
sM_orig.append(input)
elif _is_real_matrix(input): elif _is_real_matrix(input):
sM_list.append((1.0, input)) sM_list.append((1.0, input))
sM_orig.append(input)
else: else:
other_inputs.append(input) other_inputs.append(input)
assert len(sM_list) == len(sM_orig)
assert len(sM_list) + len(other_inputs) == len(node.inputs)
if len(sM_list) == 2: if len(sM_list) == 2:
(sL, mL), (sR, mR) = sM_list (sL, mL), (sR, mR) = sM_list
gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR) gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR)
...@@ -425,16 +437,34 @@ class GemmLocalOptimizer(LocalOptimizer): ...@@ -425,16 +437,34 @@ class GemmLocalOptimizer(LocalOptimizer):
else: else:
return gemm_of_sM_list return gemm_of_sM_list
else: else:
# Try every pair in the sM_list, trying to turn it into a gemm operation
for i in xrange(len(sM_list) - 1): for i in xrange(len(sM_list) - 1):
for j in xrange(i+1, len(sM_list)): for j in xrange(i+1, len(sM_list)):
assert i != j
sL, mL = sM_list[i] sL, mL = sM_list[i]
sR, mR = sM_list[j] sR, mR = sM_list[j]
gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR) gemm_of_sM_list = beta_L_plus_alpha_M(sL, mL, sR, mR)
if gemm_of_sM_list: if gemm_of_sM_list:
assert len(gemm_of_sM_list) == 1 assert len(gemm_of_sM_list) == 1
inputs_without_ij = \ inputs_without_ij = \
[input for k, input in enumerate(node.inputs) if k not in (i,j)] [input for k, input in enumerate(sM_orig) if k not in (i,j)]
return [T.add( *(inputs_without_ij + gemm_of_sM_list + other_inputs))]
new_add_inputs = (inputs_without_ij + gemm_of_sM_list + other_inputs)
if False: #SUPER DEBUG MODE :(
if len(new_add_inputs) + 1 != len(node.inputs):
print 'inputs', node.inputs
print 'sM, other', sM_list, other_inputs
print 'i,j', i, j
print 'gemm', gemm_of_sM_list
print 'without ij', inputs_without_ij
print 'new inputs', new_add_inputs
sys.exit(1)
# this should be True because we've combined a pair of arguments
# into a single GEMM
assert len(new_add_inputs) + 1 == len(node.inputs)
return [T.add(*new_add_inputs)]
return False return False
@staticmethod @staticmethod
...@@ -443,7 +473,7 @@ class GemmLocalOptimizer(LocalOptimizer): ...@@ -443,7 +473,7 @@ class GemmLocalOptimizer(LocalOptimizer):
if not isinstance(exc, InconsistencyError): if not isinstance(exc, InconsistencyError):
traceback.print_exc() traceback.print_exc()
else: else:
#print 'GEMM caused cycle, forget it.' #print 'GEMM caused cycle, it happens.'
pass pass
@staticmethod @staticmethod
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论