提交 690b487a authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

......@@ -293,8 +293,12 @@ def _pickle_Function(f):
defaults = []
for (input, indices, inputs), (required, refeed, default) in zip(f.indices, f.defaults):
if isinstance(input, SymbolicInputKit):
li = len(indices)
if not default:
defaults.append(ins[:li])
else:
defaults.append(default)
ins[:len(indices)] = []
ins[:li] = []
else:
defaults.append(ins[0])
del ins[0]
......
......@@ -36,6 +36,8 @@ def check_equal_numpy(x, y):
"""
if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
return x.dtype == y.dtype and x.shape == y.shape and numpy.any(abs(x - y) < 1e-10)
elif isinstance(x, numpy.random.RandomState) and isinstance(y, numpy.random.RandomState):
return all(numpy.all(a==b) for a, b in zip(x.__getstate__(), y.__getstate__()))
else:
return x == y
......
......@@ -23,10 +23,11 @@ def out2in(*local_opts):
order = 'out_to_in',
failure_callback = lambda exc,opt,pairs: None)
def in2out(*local_opts):
def in2out(*local_opts, **kwargs):
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts),
order = 'in_to_out',
failure_callback = lambda exc,opt,pairs: None)
failure_callback = lambda exc,opt,pairs: None,
**kwargs)
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
......
......@@ -184,8 +184,9 @@ def random_make_inplace(node):
op = node.op
if isinstance(op, RandomFunction) and not op.inplace:
return RandomFunction(op.fn, op.outtype, *op.args, **dict(inplace=True)).make_node(*node.inputs).outputs
return False
compile.optdb.register('random_make_inplace', opt.in2out(random_make_inplace), 99, 'fast_run', 'inplace')
compile.optdb.register('random_make_inplace', opt.in2out(random_make_inplace, ignore_newtrees=True), 99, 'fast_run', 'inplace')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论