提交 8d623587 authored 作者: James Bergstra's avatar James Bergstra

Fixes #223 and the potential optdb registry bug"

上级 5a1fd228
......@@ -8,7 +8,7 @@ import theano
import theano.tensor as T
import theano.sandbox
import theano.sandbox.wraplinker
from theano.compile import module
from theano.compile import module, Mode
if 0:
class Opt(object):
......@@ -130,32 +130,30 @@ if 0:
self.merge(env)
def linker(print_prog=False):
if 1:
print 'wtf?'
#return theano.gof.OpWiseCLinker()
imap = {None:'-'}
def blah(i, node, thunk):
imap[node] = str(i)
if print_prog:# and node.op.__class__ is T.DimShuffle:
if False and node.op == T.DimShuffle((), ['x', 'x'], inplace = True):
print node.op == T.DimShuffle((), ['x', 'x'], inplace = True),
print node.inputs[0], type(node.inputs[0]),
print node.inputs[0].equals(T.constant(2)),
outputs = node.outputs
inputs = theano.gof.graph.inputs(outputs)
print 'node ', i, node,
print ':'.join([imap[inp.owner] for inp in node.inputs])
#print theano.sandbox.pprint.pp.process_graph(inputs, outputs)
return theano.sandbox.wraplinker.WrapLinkerMany(
[theano.gof.OpWiseCLinker()],
[theano.sandbox.wraplinker.run_all
,blah
#,theano.sandbox.wraplinker.numpy_notall_isfinite
])
else:
return theano.gof.OpWiseCLinker()
def linker(print_prog=True):
if 1:
imap = {None:'-'}
def blah(i, node, thunk):
imap[node] = str(i)
if print_prog:# and node.op.__class__ is T.DimShuffle:
if False and node.op == T.DimShuffle((), ['x', 'x'], inplace = True):
print node.op == T.DimShuffle((), ['x', 'x'], inplace = True),
print node.inputs[0], type(node.inputs[0]),
print node.inputs[0].equals(T.constant(2)),
outputs = node.outputs
inputs = theano.gof.graph.inputs(outputs)
print 'node ', i, node,
print ':'.join([imap[inp.owner] for inp in node.inputs])
#print theano.sandbox.pprint.pp.process_graph(inputs, outputs)
return theano.sandbox.wraplinker.WrapLinkerMany(
[theano.gof.OpWiseCLinker()],
[theano.sandbox.wraplinker.run_all
,blah
#,theano.sandbox.wraplinker.numpy_notall_isfinite
])
else:
return theano.gof.OpWiseCLinker()
class M(module.Module):
......@@ -167,11 +165,14 @@ class M(module.Module):
self.a = module.Member(T.vector('a')) # hid bias
self.b = module.Member(T.vector('b')) # output bias
hid = T.tanh(T.dot(x, self.w) + self.a)
self.hid = T.tanh(T.dot(x, self.w) + self.a)
hid = self.hid
out = T.tanh(T.dot(hid, self.w.T) + self.b)
self.out = T.tanh(T.dot(hid, self.w.T) + self.b)
out = self.out
err = 0.5 * T.sum((out - x)**2)
self.err = 0.5 * T.sum((out - x)**2)
err = self.err
params = [self.w, self.a, self.b]
......@@ -182,7 +183,8 @@ class M(module.Module):
self.step = module.Method([x], err, updates=dict(updates))
mod = M()
m = mod.make(mode='FAST_RUN')
#m = mod.make(mode='FAST_RUN')
m = mod.make(mode=Mode(optimizer='fast_run', linker=linker()))
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
rng = numpy.random.RandomState(342)
......
#!/usr/bin/env python2.5
from __future__ import absolute_import
import numpy as N
import sys
import time
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
lr = 0.01
rng = N.random.RandomState(342)
w = rng.rand(nout, nhid)
a = rng.randn(nhid) * 0.0
b = rng.randn(nout) * 0.0
x = (rng.rand(neg, nout)-0.5) * 1.5
t = time.time()
for i in xrange(niter):
hid = N.tanh(N.dot(x, w) + a)
out = N.tanh(N.dot(hid, w.T) + b)
g_out = out - x
err = 0.5 * N.sum(g_out**2)
g_hidwt = g_out * (1.0 - out**2)
b -= lr * N.sum(g_hidwt, axis=0)
g_hid = N.dot(g_hidwt, w)
g_hidin = g_hid * (1.0 - hid**2)
w -= lr * (N.dot(g_hidwt.T, hid) + N.dot(x.T, g_hidin))
a -= lr * N.sum(g_hidin, axis=0)
print 'time: ',time.time() - t, 'err: ', err
......@@ -15,6 +15,7 @@ from collections import deque, defaultdict
import destroyhandler as dh
import sys
_optimizer_idx = [0]
class Optimizer(object):
"""WRITEME
......@@ -23,6 +24,12 @@ class Optimizer(object):
of transformation you could apply to an L{Env}.
"""
def __hash__(self):
if not hasattr(self, '_optimizer_idx'):
self._optimizer_idx = _optimizer_idx[0]
_optimizer_idx[0] += 1
return self._optimizer_idx
def apply(self, env):
"""WRITEME
Applies the optimization to the provided L{Env}. It may use all
......@@ -72,6 +79,7 @@ def optimizer(f):
class SeqOptimizer(Optimizer, list):
#inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME
Takes a list of L{Optimizer} instances and applies them
sequentially.
......@@ -99,13 +107,13 @@ class SeqOptimizer(Optimizer, list):
raise
def __eq__(self, other):
#added to override the list's __eq__ implementation
return id(self) == id(other)
def __neq__(self, other):
#added to override the list's __neq__ implementation
return id(self) != id(other)
def __hash__(self):
return hash(id(self))
def __str__(self):
return "SeqOpt(%s)" % list.__str__(self)
......@@ -221,7 +229,7 @@ def MergeOptMerge(opt):
### Local Optimizers ###
########################
class LocalOptimizer(utils.object2):
class LocalOptimizer(Optimizer, utils.object2):
"""WRITEME"""
def transform(self, node):
......
......@@ -4,16 +4,31 @@ import opt
class DB(object):
def __hash__(self):
if not hasattr(self, '_optimizer_idx'):
self._optimizer_idx = opt._optimizer_idx[0]
opt._optimizer_idx[0] += 1
return self._optimizer_idx
def __init__(self):
self.__db__ = defaultdict(set)
self._names = set()
def register(self, name, obj, *tags):
# N.B. obj is not an instance of class Optimizer.
# It is an instance of a DB.In the tests for example,
# this is not always the case.
if not isinstance(obj, (DB, opt.Optimizer)):
raise Exception('wtf', obj)
obj.name = name
if name in self.__db__:
raise ValueError('The name of the object cannot be an existing tag or the name of another existing object.', obj, name)
self.__db__[name] = set([obj])
self._names.add(name)
for tag in tags:
if tag in self._names:
raise ValueError('The tag of the object collides with a name.', obj, tag)
self.__db__[tag].add(obj)
def __query__(self, q):
......
from theano.gof.optdb import *
from unittest import TestCase
class Test_DB(TestCase):
def test_0(self):
class Opt(opt.Optimizer): #inheritance buys __hash__
name = 'blah'
db = DB()
db.register('a', Opt())
db.register('b', Opt())
db.register('c', Opt(), 'z', 'asdf')
try:
db.register('c', Opt()) #name taken
self.fail()
except ValueError, e:
if e[0].startswith("The name"):
pass
else:
raise
except:
self.fail()
try:
db.register('z', Opt()) #name collides with tag
self.fail()
except ValueError, e:
if e[0].startswith("The name"):
pass
else:
raise
except:
self.fail()
try:
db.register('u', Opt(), 'b') #name new but tag collides with name
self.fail()
except ValueError, e:
if e[0].startswith("The tag"):
pass
else:
raise
except:
self.fail()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论