提交 130933ce authored 作者: Olivier Breuleux's avatar Olivier Breuleux

bug fix with kits

上级 5272fcf0
...@@ -293,8 +293,12 @@ def _pickle_Function(f): ...@@ -293,8 +293,12 @@ def _pickle_Function(f):
defaults = [] defaults = []
for (input, indices, inputs), (required, refeed, default) in zip(f.indices, f.defaults): for (input, indices, inputs), (required, refeed, default) in zip(f.indices, f.defaults):
if isinstance(input, SymbolicInputKit): if isinstance(input, SymbolicInputKit):
defaults.append(default) li = len(indices)
ins[:len(indices)] = [] if not default:
defaults.append(ins[:li])
else:
defaults.append(default)
ins[:li] = []
else: else:
defaults.append(ins[0]) defaults.append(ins[0])
del ins[0] del ins[0]
......
...@@ -36,6 +36,8 @@ def check_equal_numpy(x, y): ...@@ -36,6 +36,8 @@ def check_equal_numpy(x, y):
""" """
if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray): if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
return x.dtype == y.dtype and x.shape == y.shape and numpy.any(abs(x - y) < 1e-10) return x.dtype == y.dtype and x.shape == y.shape and numpy.any(abs(x - y) < 1e-10)
elif isinstance(x, numpy.random.RandomState) and isinstance(y, numpy.random.RandomState):
return all(numpy.all(a==b) for a, b in zip(x.__getstate__(), y.__getstate__()))
else: else:
return x == y return x == y
......
...@@ -23,10 +23,11 @@ def out2in(*local_opts): ...@@ -23,10 +23,11 @@ def out2in(*local_opts):
order = 'out_to_in', order = 'out_to_in',
failure_callback = lambda exc,opt,pairs: None) failure_callback = lambda exc,opt,pairs: None)
def in2out(*local_opts): def in2out(*local_opts, **kwargs):
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts),
order = 'in_to_out', order = 'in_to_out',
failure_callback = lambda exc,opt,pairs: None) failure_callback = lambda exc,opt,pairs: None,
**kwargs)
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c) # gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
......
...@@ -184,8 +184,9 @@ def random_make_inplace(node): ...@@ -184,8 +184,9 @@ def random_make_inplace(node):
op = node.op op = node.op
if isinstance(op, RandomFunction) and not op.inplace: if isinstance(op, RandomFunction) and not op.inplace:
return RandomFunction(op.fn, op.outtype, *op.args, **dict(inplace=True)).make_node(*node.inputs).outputs return RandomFunction(op.fn, op.outtype, *op.args, **dict(inplace=True)).make_node(*node.inputs).outputs
return False
compile.optdb.register('random_make_inplace', opt.in2out(random_make_inplace), 99, 'fast_run', 'inplace') compile.optdb.register('random_make_inplace', opt.in2out(random_make_inplace, ignore_newtrees=True), 99, 'fast_run', 'inplace')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论