提交 0d8dc459 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed tests

上级 f6e05092
...@@ -9,7 +9,7 @@ from gof import \ ...@@ -9,7 +9,7 @@ from gof import \
Type, Generic, generic, \ Type, Generic, generic, \
object2, utils object2, utils
from compile import function, eval_outputs, fast_compute, OpFromGraph from compile import FunctionMaker, function, OpFromGraph #, eval_outputs, fast_compute
import tensor import tensor
import tensor_random import tensor_random
......
差异被折叠。
...@@ -358,8 +358,6 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True): ...@@ -358,8 +358,6 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
else: else:
d[input] = input d[input] = input
for apply in io_toposort(i, o): for apply in io_toposort(i, o):
for input in apply.inputs: for input in apply.inputs:
if input not in d: if input not in d:
...@@ -374,6 +372,10 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True): ...@@ -374,6 +372,10 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
for output, new_output in zip(apply.outputs, new_apply.outputs): for output, new_output in zip(apply.outputs, new_apply.outputs):
d[output] = new_output d[output] = new_output
for output in o:
if output not in d:
d[output] = output.clone()
return d return d
def general_toposort(r_out, deps, debug_print = False): def general_toposort(r_out, deps, debug_print = False):
......
...@@ -113,11 +113,16 @@ from collections import deque ...@@ -113,11 +113,16 @@ from collections import deque
class RandomKit(SymbolicInputKit): class RandomKit(SymbolicInputKit):
def __init__(self, name, value = None):
super(RandomKit, self).__init__(name)
self.value = value
def gen(self, op, *args, **kwargs): def gen(self, op, *args, **kwargs):
r = gof.generic() r = gof.generic()
new_r, out = op(r, *args, **kwargs) new_r, out = op(r, *args, **kwargs)
self.add_input(SymbolicInput(r, update = new_r)) self.add_input(SymbolicInput(r, update = new_r))
out.rng = r out.rng = r
out.auto = self
return out return out
def distribute(self, value, indices, containers): def distribute(self, value, indices, containers):
...@@ -135,7 +140,18 @@ class RandomKit(SymbolicInputKit): ...@@ -135,7 +140,18 @@ class RandomKit(SymbolicInputKit):
def binomial(self, *args, **kwargs): def binomial(self, *args, **kwargs):
return self.gen(binomial, *args, **kwargs) return self.gen(binomial, *args, **kwargs)
rk = RandomKit('rk') def uniform(self, *args, **kwargs):
return self.gen(uniform, *args, **kwargs)
def normal(self, *args, **kwargs):
return self.gen(normal, *args, **kwargs)
def random_integers(self, *args, **kwargs):
return self.gen(random_integers, *args, **kwargs)
rk = RandomKit('rk', 0xBAD5EED)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论