提交 d2c14c70 authored 作者: James Bergstra's avatar James Bergstra

cleaned up code related to gemm bug fix

上级 b6b2c608
...@@ -442,7 +442,8 @@ class GemmLocalOptimizer(LocalOptimizer): ...@@ -442,7 +442,8 @@ class GemmLocalOptimizer(LocalOptimizer):
if not isinstance(exc, InconsistencyError): if not isinstance(exc, InconsistencyError):
traceback.print_exc() traceback.print_exc()
else: else:
print 'GEMM caused cycle, forget it.' #print 'GEMM caused cycle, forget it.'
pass
@staticmethod @staticmethod
def _as_scalar(res): def _as_scalar(res):
...@@ -497,16 +498,15 @@ class GemmLocalOptimizer(LocalOptimizer): ...@@ -497,16 +498,15 @@ class GemmLocalOptimizer(LocalOptimizer):
def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip #print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M) #EXPRESSION: (beta * L) + (alpha * M)
if True:
if res_is_a(L, T.sqrt):
print 'CLIENTS OF L', L, L.clients
if res_is_a(M, _dot22, 1): if res_is_a(M, _dot22, 1):
Ml, Mr = M.owner.inputs Ml, Mr = M.owner.inputs
rval = [gemm(L, alpha, Ml, Mr, beta)] rval = [gemm(L, alpha, Ml, Mr, beta)]
print 'GEMM 0', rval, beta, L, alpha, M #print 'GEMM 0', rval, beta, L, alpha, M
return rval return rval
# this is False'd out because of inadequate testing.
# TODO see ticket #237
if False and res_is_a(M, gemm, 1): if False and res_is_a(M, gemm, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b))) #EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b)))
#EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v) #EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v)
......
import traceback
import theano.tensor as T import theano.tensor as T
from ...gof import Env from ...gof import Env
import numpy import numpy
...@@ -258,11 +258,11 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]): ...@@ -258,11 +258,11 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]):
try: try:
f = function([In(ii, mutable=True) for ii in i],o, mode='FAST_RUN') f = function([In(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
for node in f.maker.env.nodes: for node in f.maker.env.nodes:
if node.op == T.dot: raise Warning('dot in graph') if node.op == T.dot: raise Warning('dot not changed to gemm in graph')
if node.op == _dot22: raise Warning('_dot22 in graph') if node.op == _dot22: raise Warning('_dot22 not changed to gemm in graph')
g = function(i, o, mode=compile.Mode(linker='py', optimizer=None)) g = function(i, o, mode=compile.Mode(linker='py', optimizer=None))
for node in g.maker.env.nodes: for node in g.maker.env.nodes:
if node.op == gemm: raise Warning('gemm in graph') if node.op == gemm: raise Exception('gemm in original graph')
rng = numpy.random.RandomState(234) rng = numpy.random.RandomState(234)
r0 = f(*[rng.randn(*sh) for sh in ishapes]) r0 = f(*[rng.randn(*sh) for sh in ishapes])
...@@ -275,9 +275,11 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]): ...@@ -275,9 +275,11 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]):
for node in f.maker.env.toposort(): for node in f.maker.env.toposort():
print 'GRAPH', node print 'GRAPH', node
raise raise
except Warning: except Warning, e:
for node in f.maker.env.toposort(): #for node in f.maker.env.toposort():
print 'GRAPH', node # print 'GRAPH', node
print 'WARNING:', e
#traceback.print_exc()
def test_gemm_opt0(): def test_gemm_opt0():
......
...@@ -457,13 +457,7 @@ def create(window_size=3, ...@@ -457,13 +457,7 @@ def create(window_size=3,
model = architecture.make(input_size=input_dimension, input_representation_size=token_representation_size, hidden_representation_size=concatenated_representation_size, output_size=output_vocabsize, lr=lr, seed=seed, noise_level=noise_level, qfilter_relscale=qfilter_relscale, mode=compile_mode) model = architecture.make(input_size=input_dimension, input_representation_size=token_representation_size, hidden_representation_size=concatenated_representation_size, output_size=output_vocabsize, lr=lr, seed=seed, noise_level=noise_level, qfilter_relscale=qfilter_relscale, mode=compile_mode)
return model return model
from theano import gof def test_naacl_model(optimizer='fast_run'):
JTEST = theano.compile.mode.optdb.query(*sys.argv[2:])
print 'JTEST', JTEST
theano.compile.register_optimizer('JTEST', JTEST)
if __name__ == '__main__':
optimizer = eval(sys.argv[1])
m = create(compile_mode = theano.Mode(linker='c|py', optimizer=optimizer)) m = create(compile_mode = theano.Mode(linker='c|py', optimizer=optimizer))
prog_str = [] prog_str = []
idx_of_node = {} idx_of_node = {}
...@@ -488,11 +482,24 @@ if __name__ == '__main__': ...@@ -488,11 +482,24 @@ if __name__ == '__main__':
for i in xrange(10): for i in xrange(10):
for i in xrange(10): for i in xrange(10):
m.pretraining_update(*inputs) m.pretraining_update(*inputs)
print m.pretraining_update(*inputs) s0, s1 = [str(i) for i in m.pretraining_update(*inputs)]
print s0, s1
if s0 + ' ' + s1 != '0.315775007436 0.132479386981':
raise ValueError('pretraining update values do not match')
print 'FINETUNING GRAPH' print 'FINETUNING GRAPH'
print 'SUPERVISED PHASE COSTS (%s)'%optimizer print 'SUPERVISED PHASE COSTS (%s)'%optimizer
for i in xrange(10): for i in xrange(10):
for i in xrange(10): for i in xrange(10):
m.finetuning_update(*(inputs + [targets])) #the 0 is the target m.finetuning_update(*(inputs + [targets]))
print m.finetuning_update(*(inputs + [targets])) #the 0 is the target s0 = str(m.finetuning_update(*(inputs + [targets])))
print s0
if s0 != '15.8609933666':
raise ValueError('finetuning values do not match')
if __name__ == '__main__':
from theano import gof
JTEST = theano.compile.mode.optdb.query(*sys.argv[2:])
print 'JTEST', JTEST
theano.compile.register_optimizer('JTEST', JTEST)
optimizer = eval(sys.argv[1])
test_naacl_model(optimizer)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论