提交 8d623587 authored 作者: James Bergstra's avatar James Bergstra

Fixes #223 and the potential optdb registry bug"

上级 5a1fd228
...@@ -8,7 +8,7 @@ import theano ...@@ -8,7 +8,7 @@ import theano
import theano.tensor as T import theano.tensor as T
import theano.sandbox import theano.sandbox
import theano.sandbox.wraplinker import theano.sandbox.wraplinker
from theano.compile import module from theano.compile import module, Mode
if 0: if 0:
class Opt(object): class Opt(object):
...@@ -130,32 +130,30 @@ if 0: ...@@ -130,32 +130,30 @@ if 0:
self.merge(env) self.merge(env)
def linker(print_prog=False): def linker(print_prog=True):
if 1: if 1:
print 'wtf?' imap = {None:'-'}
#return theano.gof.OpWiseCLinker() def blah(i, node, thunk):
imap = {None:'-'} imap[node] = str(i)
def blah(i, node, thunk): if print_prog:# and node.op.__class__ is T.DimShuffle:
imap[node] = str(i) if False and node.op == T.DimShuffle((), ['x', 'x'], inplace = True):
if print_prog:# and node.op.__class__ is T.DimShuffle: print node.op == T.DimShuffle((), ['x', 'x'], inplace = True),
if False and node.op == T.DimShuffle((), ['x', 'x'], inplace = True): print node.inputs[0], type(node.inputs[0]),
print node.op == T.DimShuffle((), ['x', 'x'], inplace = True), print node.inputs[0].equals(T.constant(2)),
print node.inputs[0], type(node.inputs[0]), outputs = node.outputs
print node.inputs[0].equals(T.constant(2)), inputs = theano.gof.graph.inputs(outputs)
outputs = node.outputs print 'node ', i, node,
inputs = theano.gof.graph.inputs(outputs) print ':'.join([imap[inp.owner] for inp in node.inputs])
print 'node ', i, node, #print theano.sandbox.pprint.pp.process_graph(inputs, outputs)
print ':'.join([imap[inp.owner] for inp in node.inputs])
#print theano.sandbox.pprint.pp.process_graph(inputs, outputs) return theano.sandbox.wraplinker.WrapLinkerMany(
[theano.gof.OpWiseCLinker()],
return theano.sandbox.wraplinker.WrapLinkerMany( [theano.sandbox.wraplinker.run_all
[theano.gof.OpWiseCLinker()], ,blah
[theano.sandbox.wraplinker.run_all #,theano.sandbox.wraplinker.numpy_notall_isfinite
,blah ])
#,theano.sandbox.wraplinker.numpy_notall_isfinite else:
]) return theano.gof.OpWiseCLinker()
else:
return theano.gof.OpWiseCLinker()
class M(module.Module): class M(module.Module):
...@@ -167,11 +165,14 @@ class M(module.Module): ...@@ -167,11 +165,14 @@ class M(module.Module):
self.a = module.Member(T.vector('a')) # hid bias self.a = module.Member(T.vector('a')) # hid bias
self.b = module.Member(T.vector('b')) # output bias self.b = module.Member(T.vector('b')) # output bias
hid = T.tanh(T.dot(x, self.w) + self.a) self.hid = T.tanh(T.dot(x, self.w) + self.a)
hid = self.hid
out = T.tanh(T.dot(hid, self.w.T) + self.b) self.out = T.tanh(T.dot(hid, self.w.T) + self.b)
out = self.out
err = 0.5 * T.sum((out - x)**2) self.err = 0.5 * T.sum((out - x)**2)
err = self.err
params = [self.w, self.a, self.b] params = [self.w, self.a, self.b]
...@@ -182,7 +183,8 @@ class M(module.Module): ...@@ -182,7 +183,8 @@ class M(module.Module):
self.step = module.Method([x], err, updates=dict(updates)) self.step = module.Method([x], err, updates=dict(updates))
mod = M() mod = M()
m = mod.make(mode='FAST_RUN') #m = mod.make(mode='FAST_RUN')
m = mod.make(mode=Mode(optimizer='fast_run', linker=linker()))
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]] neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
rng = numpy.random.RandomState(342) rng = numpy.random.RandomState(342)
......
#!/usr/bin/env python2.5
from __future__ import absolute_import
import numpy as N
import sys
import time
neg, nout, nhid, niter = [int(a) for a in sys.argv[1:]]
lr = 0.01
rng = N.random.RandomState(342)
w = rng.rand(nout, nhid)
a = rng.randn(nhid) * 0.0
b = rng.randn(nout) * 0.0
x = (rng.rand(neg, nout)-0.5) * 1.5
t = time.time()
for i in xrange(niter):
hid = N.tanh(N.dot(x, w) + a)
out = N.tanh(N.dot(hid, w.T) + b)
g_out = out - x
err = 0.5 * N.sum(g_out**2)
g_hidwt = g_out * (1.0 - out**2)
b -= lr * N.sum(g_hidwt, axis=0)
g_hid = N.dot(g_hidwt, w)
g_hidin = g_hid * (1.0 - hid**2)
w -= lr * (N.dot(g_hidwt.T, hid) + N.dot(x.T, g_hidin))
a -= lr * N.sum(g_hidin, axis=0)
print 'time: ',time.time() - t, 'err: ', err
...@@ -15,6 +15,7 @@ from collections import deque, defaultdict ...@@ -15,6 +15,7 @@ from collections import deque, defaultdict
import destroyhandler as dh import destroyhandler as dh
import sys import sys
_optimizer_idx = [0]
class Optimizer(object): class Optimizer(object):
"""WRITEME """WRITEME
...@@ -23,6 +24,12 @@ class Optimizer(object): ...@@ -23,6 +24,12 @@ class Optimizer(object):
of transformation you could apply to an L{Env}. of transformation you could apply to an L{Env}.
""" """
def __hash__(self):
if not hasattr(self, '_optimizer_idx'):
self._optimizer_idx = _optimizer_idx[0]
_optimizer_idx[0] += 1
return self._optimizer_idx
def apply(self, env): def apply(self, env):
"""WRITEME """WRITEME
Applies the optimization to the provided L{Env}. It may use all Applies the optimization to the provided L{Env}. It may use all
...@@ -72,6 +79,7 @@ def optimizer(f): ...@@ -72,6 +79,7 @@ def optimizer(f):
class SeqOptimizer(Optimizer, list): class SeqOptimizer(Optimizer, list):
#inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME """WRITEME
Takes a list of L{Optimizer} instances and applies them Takes a list of L{Optimizer} instances and applies them
sequentially. sequentially.
...@@ -99,13 +107,13 @@ class SeqOptimizer(Optimizer, list): ...@@ -99,13 +107,13 @@ class SeqOptimizer(Optimizer, list):
raise raise
def __eq__(self, other): def __eq__(self, other):
#added to override the list's __eq__ implementation
return id(self) == id(other) return id(self) == id(other)
def __neq__(self, other): def __neq__(self, other):
#added to override the list's __neq__ implementation
return id(self) != id(other) return id(self) != id(other)
def __hash__(self):
return hash(id(self))
def __str__(self): def __str__(self):
return "SeqOpt(%s)" % list.__str__(self) return "SeqOpt(%s)" % list.__str__(self)
...@@ -221,7 +229,7 @@ def MergeOptMerge(opt): ...@@ -221,7 +229,7 @@ def MergeOptMerge(opt):
### Local Optimizers ### ### Local Optimizers ###
######################## ########################
class LocalOptimizer(utils.object2): class LocalOptimizer(Optimizer, utils.object2):
"""WRITEME""" """WRITEME"""
def transform(self, node): def transform(self, node):
......
...@@ -4,16 +4,31 @@ import opt ...@@ -4,16 +4,31 @@ import opt
class DB(object): class DB(object):
def __hash__(self):
if not hasattr(self, '_optimizer_idx'):
self._optimizer_idx = opt._optimizer_idx[0]
opt._optimizer_idx[0] += 1
return self._optimizer_idx
def __init__(self): def __init__(self):
self.__db__ = defaultdict(set) self.__db__ = defaultdict(set)
self._names = set()
def register(self, name, obj, *tags): def register(self, name, obj, *tags):
# N.B. obj is not an instance of class Optimizer.
# It is an instance of a DB.In the tests for example,
# this is not always the case.
if not isinstance(obj, (DB, opt.Optimizer)):
raise Exception('wtf', obj)
obj.name = name obj.name = name
if name in self.__db__: if name in self.__db__:
raise ValueError('The name of the object cannot be an existing tag or the name of another existing object.', obj, name) raise ValueError('The name of the object cannot be an existing tag or the name of another existing object.', obj, name)
self.__db__[name] = set([obj]) self.__db__[name] = set([obj])
self._names.add(name)
for tag in tags: for tag in tags:
if tag in self._names:
raise ValueError('The tag of the object collides with a name.', obj, tag)
self.__db__[tag].add(obj) self.__db__[tag].add(obj)
def __query__(self, q): def __query__(self, q):
......
from theano.gof.optdb import *
from unittest import TestCase
class Test_DB(TestCase):
def test_0(self):
class Opt(opt.Optimizer): #inheritance buys __hash__
name = 'blah'
db = DB()
db.register('a', Opt())
db.register('b', Opt())
db.register('c', Opt(), 'z', 'asdf')
try:
db.register('c', Opt()) #name taken
self.fail()
except ValueError, e:
if e[0].startswith("The name"):
pass
else:
raise
except:
self.fail()
try:
db.register('z', Opt()) #name collides with tag
self.fail()
except ValueError, e:
if e[0].startswith("The name"):
pass
else:
raise
except:
self.fail()
try:
db.register('u', Opt(), 'b') #name new but tag collides with name
self.fail()
except ValueError, e:
if e[0].startswith("The tag"):
pass
else:
raise
except:
self.fail()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论