提交 b60d4b35 authored 作者: Frederic's avatar Frederic

[ENH] Do not always compile c code when doing constant_folding

上级 fb2066db
......@@ -1193,6 +1193,13 @@ class Elemwise(OpenMPOp):
else:
return ()
def python_constant_folding(self, node):
"""
Return True if we do not want to compile c code
when doing constant folding of this node.
"""
return node.outputs[0].ndim == 0
# def elemwise_to_scal(fgraph):
# TODO: why is this commented out? should it be removed?
# it has needed maintenance despite being commented
......
......@@ -4508,9 +4508,21 @@ def constant_folding(node):
for o in node.outputs:
storage_map[o] = [None]
compute_map[o] = [False]
if (hasattr(node.op, 'python_constant_folding') and
node.op.python_constant_folding(node)):
thunk = node.op.make_thunk(node, storage_map, compute_map,
no_recycling=[])
old_value = getattr(node.op, '_op_use_c_code', False)
try:
node.op._op_use_c_code = False
thunk = node.op.make_thunk(node,
storage_map,
compute_map,
[])
finally:
node.op._op_use_c_code = old_value
else:
thunk = node.op.make_thunk(node, storage_map, compute_map,
no_recycling=[])
required = thunk()
assert not required # a node whose inputs are all provided should always
......
......@@ -3671,6 +3671,17 @@ def test_constant_folding():
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
# Test that we do not crash when constant folding elemwise scalar
# as they should not generate c code.
x = tensor.constant(3)
assert x.ndim == 0
mode = theano.compile.get_mode("FAST_COMPILE").excluding("fusion")
f = theano.function([], [x * 2, x + x], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert all([isinstance(n.op, DeepCopyOp) for n in topo])
def test_constant_get_stabilized():
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论