提交 0d8dc459 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed tests

上级 f6e05092
......@@ -9,7 +9,7 @@ from gof import \
Type, Generic, generic, \
object2, utils
from compile import function, eval_outputs, fast_compute, OpFromGraph
from compile import FunctionMaker, function, OpFromGraph #, eval_outputs, fast_compute
import tensor
import tensor_random
......
差异被折叠。
......@@ -358,8 +358,6 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
else:
d[input] = input
for apply in io_toposort(i, o):
for input in apply.inputs:
if input not in d:
......@@ -374,6 +372,10 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
for output, new_output in zip(apply.outputs, new_apply.outputs):
d[output] = new_output
for output in o:
if output not in d:
d[output] = output.clone()
return d
def general_toposort(r_out, deps, debug_print = False):
......
......@@ -113,11 +113,16 @@ from collections import deque
class RandomKit(SymbolicInputKit):
def __init__(self, name, value = None):
super(RandomKit, self).__init__(name)
self.value = value
def gen(self, op, *args, **kwargs):
r = gof.generic()
new_r, out = op(r, *args, **kwargs)
self.add_input(SymbolicInput(r, update = new_r))
out.rng = r
out.auto = self
return out
def distribute(self, value, indices, containers):
......@@ -135,7 +140,18 @@ class RandomKit(SymbolicInputKit):
def binomial(self, *args, **kwargs):
return self.gen(binomial, *args, **kwargs)
rk = RandomKit('rk')
def uniform(self, *args, **kwargs):
return self.gen(uniform, *args, **kwargs)
def normal(self, *args, **kwargs):
return self.gen(normal, *args, **kwargs)
def random_integers(self, *args, **kwargs):
return self.gen(random_integers, *args, **kwargs)
rk = RandomKit('rk', 0xBAD5EED)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论