提交 3cefe303 authored 作者: James Bergstra's avatar James Bergstra

upgraded constant-folding optimization to use make_thunk

上级 bcea28c3
......@@ -3056,30 +3056,33 @@ def constant_folding(node):
for input in node.inputs:
if not isinstance(input, Constant):
return False
try:
storage = [[None] for output in node.outputs]
node.op.perform(node, [x.data for x in node.inputs], storage)
except MethodNotDefined:
tmp_inputs = [x.type() for x in node.inputs]
f = compile.function(
inputs=tmp_inputs,
outputs=node.op.make_node(*tmp_inputs).outputs,
mode=compile.Mode(linker='c|py',optimizer=None))
xvals = f(*[x.data for x in node.inputs])
storage = [[xv] for xv in xvals]
msg = []
assert len(storage) == len(node.outputs)
for s, output in zip(storage, node.outputs):
#condition: all inputs are constant
storage_map=dict([(i,[i.data]) for i in node.inputs])
compute_map=dict([(i,[True]) for i in node.inputs])
for o in node.outputs:
storage_map[o] = [None]
compute_map[o] = [False]
thunk = node.op.make_thunk(node, storage_map, compute_map,
no_recycling=[])
required = thunk()
assert not required # a node whose inputs are all provided should always
# return successfully
rval = []
for output in node.outputs:
assert compute_map[output][0], (output, storage_map[output][0])
try:
constant = output.type.Constant
except:
except AttributeError:
constant = Constant
msg += [constant(output.type, s[0])]
return msg
rval.append(constant(output.type, storage_map[output][0]))
return rval
register_canonicalize(constant_folding, 'fast_compile')
register_stabilize(constant_folding) # because
register_stabilize(constant_folding)
register_specialize(constant_folding)
def _is_1(expr):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论