提交 d5ae91e1 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

included module in theano.compile

上级 664e4261
...@@ -3,7 +3,7 @@ sys.path.insert(0, '..') ...@@ -3,7 +3,7 @@ sys.path.insert(0, '..')
import theano import theano
from theano import tensor as T from theano import tensor as T
from theano.tensor import nnet_ops from theano.tensor import nnet_ops
from theano.sandbox import module from theano.compile import module
from theano.sandbox import pprint from theano.sandbox import pprint
import numpy as N import numpy as N
......
...@@ -11,3 +11,6 @@ from io import * ...@@ -11,3 +11,6 @@ from io import *
import builders import builders
from builders import * from builders import *
import module
from module import *
import theano from .. import gof
from theano import gof, compile
from collections import defaultdict from collections import defaultdict
from itertools import chain from itertools import chain
from functools import partial from functools import partial
from theano.gof.utils import scratchpad
from copy import copy from copy import copy
import mode
import function_module as F
#from ..sandbox import pprint
def join(*args): def join(*args):
...@@ -137,7 +137,6 @@ class Member(_RComponent): ...@@ -137,7 +137,6 @@ class Member(_RComponent):
from theano.sandbox import pprint
class Method(Component): class Method(Component):
def __init__(self, inputs, outputs, updates = {}, kits = [], **kwupdates): def __init__(self, inputs, outputs, updates = {}, kits = [], **kwupdates):
...@@ -197,13 +196,13 @@ class Method(Component): ...@@ -197,13 +196,13 @@ class Method(Component):
else: else:
return gof.Container(r, storage = [None]) return gof.Container(r, storage = [None])
inputs = self.inputs inputs = self.inputs
inputs = [compile.In(result = input, inputs = [mode.In(result = input,
value = get_storage(input)) value = get_storage(input))
for input in inputs] for input in inputs]
inputs += [compile.In(result = k, inputs += [mode.In(result = k,
update = v, update = v,
value = get_storage(k, True), value = get_storage(k, True),
strict = True) strict = True)
for k, v in self.updates.iteritems()] for k, v in self.updates.iteritems()]
outputs = self.outputs outputs = self.outputs
_inputs = [x.result for x in inputs] _inputs = [x.result for x in inputs]
...@@ -211,10 +210,10 @@ class Method(Component): ...@@ -211,10 +210,10 @@ class Method(Component):
+ [x.update for x in inputs if getattr(x, 'update', False)], + [x.update for x in inputs if getattr(x, 'update', False)],
blockers = _inputs): blockers = _inputs):
if input not in _inputs and not isinstance(input, gof.Value): if input not in _inputs and not isinstance(input, gof.Value):
inputs += [compile.In(result = input, inputs += [mode.In(result = input,
value = get_storage(input, True))] value = get_storage(input, True))]
inputs += [(kit, get_storage(kit, True)) for kit in self.kits] inputs += [(kit, get_storage(kit, True)) for kit in self.kits]
return compile.function(inputs, outputs, mode) return F.function(inputs, outputs, mode)
def pretty(self, **kwargs): def pretty(self, **kwargs):
self.resolve_all() self.resolve_all()
...@@ -627,32 +626,6 @@ class KitComponent(Component): ...@@ -627,32 +626,6 @@ class KitComponent(Component):
return memo[self.kit] return memo[self.kit]
from .. import tensor as T
class RModule(FancyModule):
def __init__(self, components = {}, **kwcomponents):
super(RModule, self).__init__(components, **kwcomponents)
self.random = T.RandomKit('rkit')
self._components['_rkit'] = KitComponent(self.random)
def __wrapper__(self, x):
x = wrap(x)
if isinstance(x, Method):
x.kits += [self.random]
return x
def _instance_seed(self, inst, seed, recursive = True):
if recursive:
for path, c in self.flat_components_map(True):
if isinstance(c, RModule):
inst2 = inst
for name in path:
inst2 = inst2[name]
c._rkit.kit.distribute(seed, xrange(len(inst._rkit)), inst2._rkit)
else:
self._rkit.kit.distribute(seed, xrange(len(inst._rkit)), inst._rkit)
......
...@@ -5,7 +5,7 @@ import opt ...@@ -5,7 +5,7 @@ import opt
import raw_random import raw_random
from raw_random import \ from raw_random import \
RandomKit RandomKit, RModule
random = RandomKit('random') random = RandomKit('random')
......
...@@ -4,6 +4,7 @@ import basic as tensor ...@@ -4,6 +4,7 @@ import basic as tensor
import numpy import numpy
import functools import functools
from .. import compile
from ..compile import SymbolicInputKit, SymbolicInput from ..compile import SymbolicInputKit, SymbolicInput
from copy import copy from copy import copy
...@@ -231,3 +232,26 @@ class RandomKit(SymbolicInputKit): ...@@ -231,3 +232,26 @@ class RandomKit(SymbolicInputKit):
rk = RandomKit('rk', 0xBAD5EED) rk = RandomKit('rk', 0xBAD5EED)
class RModule(compile.FancyModule):
def __init__(self, components = {}, **kwcomponents):
super(RModule, self).__init__(components, **kwcomponents)
self.random = T.RandomKit('rkit')
self._components['_rkit'] = KitComponent(self.random)
def __wrapper__(self, x):
x = wrap(x)
if isinstance(x, compile.Method):
x.kits += [self.random]
return x
def _instance_seed(self, inst, seed, recursive = True):
if recursive:
for path, c in self.flat_components_map(True):
if isinstance(c, RModule):
inst2 = inst
for name in path:
inst2 = inst2[name]
c._rkit.kit.distribute(seed, xrange(len(inst._rkit)), inst2._rkit)
else:
self._rkit.kit.distribute(seed, xrange(len(inst._rkit)), inst._rkit)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论