提交 d2c14c70 作者: James Bergstra

cleaned up code related to gemm bug fix

上级 b6b2c608
......@@ -442,7 +442,8 @@ class GemmLocalOptimizer(LocalOptimizer):
if not isinstance(exc, InconsistencyError):
traceback.print_exc()
else:
print 'GEMM caused cycle, forget it.'
#print 'GEMM caused cycle, forget it.'
pass
@staticmethod
def _as_scalar(res):
......@@ -497,16 +498,15 @@ class GemmLocalOptimizer(LocalOptimizer):
def beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M)
if True:
if res_is_a(L, T.sqrt):
print 'CLIENTS OF L', L, L.clients
if res_is_a(M, _dot22, 1):
Ml, Mr = M.owner.inputs
rval = [gemm(L, alpha, Ml, Mr, beta)]
print 'GEMM 0', rval, beta, L, alpha, M
#print 'GEMM 0', rval, beta, L, alpha, M
return rval
# this is False'd out because of inadequate testing.
# TODO see ticket #237
if False and res_is_a(M, gemm, 1):
#EXPRESSION: (beta * L) + (alpha * (gemm(G, a, u, v, b)))
#EXPRESSION: (beta * L) + alpha * (b * G) + alpha * a * dot(u, v)
......
import traceback
import theano.tensor as T
from ...gof import Env
import numpy
......@@ -258,11 +258,11 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]):
try:
f = function([In(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
for node in f.maker.env.nodes:
if node.op == T.dot: raise Warning('dot in graph')
if node.op == _dot22: raise Warning('_dot22 in graph')
if node.op == T.dot: raise Warning('dot not changed to gemm in graph')
if node.op == _dot22: raise Warning('_dot22 not changed to gemm in graph')
g = function(i, o, mode=compile.Mode(linker='py', optimizer=None))
for node in g.maker.env.nodes:
if node.op == gemm: raise Warning('gemm in graph')
if node.op == gemm: raise Exception('gemm in original graph')
rng = numpy.random.RandomState(234)
r0 = f(*[rng.randn(*sh) for sh in ishapes])
......@@ -275,9 +275,11 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]):
for node in f.maker.env.toposort():
print 'GRAPH', node
raise
except Warning:
for node in f.maker.env.toposort():
print 'GRAPH', node
except Warning, e:
#for node in f.maker.env.toposort():
# print 'GRAPH', node
print 'WARNING:', e
#traceback.print_exc()
def test_gemm_opt0():
......
......@@ -457,13 +457,7 @@ def create(window_size=3,
model = architecture.make(input_size=input_dimension, input_representation_size=token_representation_size, hidden_representation_size=concatenated_representation_size, output_size=output_vocabsize, lr=lr, seed=seed, noise_level=noise_level, qfilter_relscale=qfilter_relscale, mode=compile_mode)
return model
from theano import gof
JTEST = theano.compile.mode.optdb.query(*sys.argv[2:])
print 'JTEST', JTEST
theano.compile.register_optimizer('JTEST', JTEST)
if __name__ == '__main__':
optimizer = eval(sys.argv[1])
def test_naacl_model(optimizer='fast_run'):
m = create(compile_mode = theano.Mode(linker='c|py', optimizer=optimizer))
prog_str = []
idx_of_node = {}
......@@ -488,11 +482,24 @@ if __name__ == '__main__':
for i in xrange(10):
for i in xrange(10):
m.pretraining_update(*inputs)
print m.pretraining_update(*inputs)
s0, s1 = [str(i) for i in m.pretraining_update(*inputs)]
print s0, s1
if s0 + ' ' + s1 != '0.315775007436 0.132479386981':
raise ValueError('pretraining update values do not match')
print 'FINETUNING GRAPH'
print 'SUPERVISED PHASE COSTS (%s)'%optimizer
for i in xrange(10):
for i in xrange(10):
m.finetuning_update(*(inputs + [targets])) #the 0 is the target
print m.finetuning_update(*(inputs + [targets])) #the 0 is the target
m.finetuning_update(*(inputs + [targets]))
s0 = str(m.finetuning_update(*(inputs + [targets])))
print s0
if s0 != '15.8609933666':
raise ValueError('finetuning values do not match')
if __name__ == '__main__':
from theano import gof
JTEST = theano.compile.mode.optdb.query(*sys.argv[2:])
print 'JTEST', JTEST
theano.compile.register_optimizer('JTEST', JTEST)
optimizer = eval(sys.argv[1])
test_naacl_model(optimizer)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论