提交 d5ae91e1 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

included module in theano.compile

上级 664e4261
......@@ -3,7 +3,7 @@ sys.path.insert(0, '..')
import theano
from theano import tensor as T
from theano.tensor import nnet_ops
from theano.sandbox import module
from theano.compile import module
from theano.sandbox import pprint
import numpy as N
......
......@@ -11,3 +11,6 @@ from io import *
import builders
from builders import *
import module
from module import *
import theano
from theano import gof, compile
from .. import gof
from collections import defaultdict
from itertools import chain
from functools import partial
from theano.gof.utils import scratchpad
from copy import copy
import mode
import function_module as F
#from ..sandbox import pprint
def join(*args):
......@@ -137,7 +137,6 @@ class Member(_RComponent):
from theano.sandbox import pprint
class Method(Component):
def __init__(self, inputs, outputs, updates = {}, kits = [], **kwupdates):
......@@ -197,13 +196,13 @@ class Method(Component):
else:
return gof.Container(r, storage = [None])
inputs = self.inputs
inputs = [compile.In(result = input,
value = get_storage(input))
inputs = [mode.In(result = input,
value = get_storage(input))
for input in inputs]
inputs += [compile.In(result = k,
update = v,
value = get_storage(k, True),
strict = True)
inputs += [mode.In(result = k,
update = v,
value = get_storage(k, True),
strict = True)
for k, v in self.updates.iteritems()]
outputs = self.outputs
_inputs = [x.result for x in inputs]
......@@ -211,10 +210,10 @@ class Method(Component):
+ [x.update for x in inputs if getattr(x, 'update', False)],
blockers = _inputs):
if input not in _inputs and not isinstance(input, gof.Value):
inputs += [compile.In(result = input,
value = get_storage(input, True))]
inputs += [mode.In(result = input,
value = get_storage(input, True))]
inputs += [(kit, get_storage(kit, True)) for kit in self.kits]
return compile.function(inputs, outputs, mode)
return F.function(inputs, outputs, mode)
def pretty(self, **kwargs):
self.resolve_all()
......@@ -627,32 +626,6 @@ class KitComponent(Component):
return memo[self.kit]
from .. import tensor as T
class RModule(FancyModule):
def __init__(self, components = {}, **kwcomponents):
super(RModule, self).__init__(components, **kwcomponents)
self.random = T.RandomKit('rkit')
self._components['_rkit'] = KitComponent(self.random)
def __wrapper__(self, x):
x = wrap(x)
if isinstance(x, Method):
x.kits += [self.random]
return x
def _instance_seed(self, inst, seed, recursive = True):
if recursive:
for path, c in self.flat_components_map(True):
if isinstance(c, RModule):
inst2 = inst
for name in path:
inst2 = inst2[name]
c._rkit.kit.distribute(seed, xrange(len(inst._rkit)), inst2._rkit)
else:
self._rkit.kit.distribute(seed, xrange(len(inst._rkit)), inst._rkit)
......
......@@ -5,7 +5,7 @@ import opt
import raw_random
from raw_random import \
RandomKit
RandomKit, RModule
random = RandomKit('random')
......
......@@ -4,6 +4,7 @@ import basic as tensor
import numpy
import functools
from .. import compile
from ..compile import SymbolicInputKit, SymbolicInput
from copy import copy
......@@ -231,3 +232,26 @@ class RandomKit(SymbolicInputKit):
rk = RandomKit('rk', 0xBAD5EED)
class RModule(compile.FancyModule):
def __init__(self, components = {}, **kwcomponents):
super(RModule, self).__init__(components, **kwcomponents)
self.random = T.RandomKit('rkit')
self._components['_rkit'] = KitComponent(self.random)
def __wrapper__(self, x):
x = wrap(x)
if isinstance(x, compile.Method):
x.kits += [self.random]
return x
def _instance_seed(self, inst, seed, recursive = True):
if recursive:
for path, c in self.flat_components_map(True):
if isinstance(c, RModule):
inst2 = inst
for name in path:
inst2 = inst2[name]
c._rkit.kit.distribute(seed, xrange(len(inst._rkit)), inst2._rkit)
else:
self._rkit.kit.distribute(seed, xrange(len(inst._rkit)), inst._rkit)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论