提交 3cefe303 authored 作者: James Bergstra's avatar James Bergstra

upgraded constant-folding optimization to use make_thunk

上级 bcea28c3
...@@ -3056,30 +3056,33 @@ def constant_folding(node): ...@@ -3056,30 +3056,33 @@ def constant_folding(node):
for input in node.inputs: for input in node.inputs:
if not isinstance(input, Constant): if not isinstance(input, Constant):
return False return False
try: #condition: all inputs are constant
storage = [[None] for output in node.outputs]
node.op.perform(node, [x.data for x in node.inputs], storage) storage_map=dict([(i,[i.data]) for i in node.inputs])
except MethodNotDefined: compute_map=dict([(i,[True]) for i in node.inputs])
tmp_inputs = [x.type() for x in node.inputs] for o in node.outputs:
f = compile.function( storage_map[o] = [None]
inputs=tmp_inputs, compute_map[o] = [False]
outputs=node.op.make_node(*tmp_inputs).outputs,
mode=compile.Mode(linker='c|py',optimizer=None)) thunk = node.op.make_thunk(node, storage_map, compute_map,
xvals = f(*[x.data for x in node.inputs]) no_recycling=[])
storage = [[xv] for xv in xvals]
required = thunk()
msg = [] assert not required # a node whose inputs are all provided should always
assert len(storage) == len(node.outputs) # return successfully
for s, output in zip(storage, node.outputs):
rval = []
for output in node.outputs:
assert compute_map[output][0], (output, storage_map[output][0])
try: try:
constant = output.type.Constant constant = output.type.Constant
except: except AttributeError:
constant = Constant constant = Constant
msg += [constant(output.type, s[0])] rval.append(constant(output.type, storage_map[output][0]))
return msg return rval
register_canonicalize(constant_folding, 'fast_compile') register_canonicalize(constant_folding, 'fast_compile')
register_stabilize(constant_folding) # because register_stabilize(constant_folding)
register_specialize(constant_folding) register_specialize(constant_folding)
def _is_1(expr): def _is_1(expr):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论