提交 b60d4b35 authored 作者: Frederic's avatar Frederic

[ENH] Do not always compile c code when doing constant_folding

上级 fb2066db
...@@ -1193,6 +1193,13 @@ class Elemwise(OpenMPOp): ...@@ -1193,6 +1193,13 @@ class Elemwise(OpenMPOp):
else: else:
return () return ()
def python_constant_folding(self, node):
"""
Return True if we do not want to compile c code
when doing constant folding of this node.
"""
return node.outputs[0].ndim == 0
# def elemwise_to_scal(fgraph): # def elemwise_to_scal(fgraph):
# TODO: why is this commented out? should it be removed? # TODO: why is this commented out? should it be removed?
# it has needed maintenance despite being commented # it has needed maintenance despite being commented
......
...@@ -4508,7 +4508,19 @@ def constant_folding(node): ...@@ -4508,7 +4508,19 @@ def constant_folding(node):
for o in node.outputs: for o in node.outputs:
storage_map[o] = [None] storage_map[o] = [None]
compute_map[o] = [False] compute_map[o] = [False]
if (hasattr(node.op, 'python_constant_folding') and
node.op.python_constant_folding(node)):
old_value = getattr(node.op, '_op_use_c_code', False)
try:
node.op._op_use_c_code = False
thunk = node.op.make_thunk(node,
storage_map,
compute_map,
[])
finally:
node.op._op_use_c_code = old_value
else:
thunk = node.op.make_thunk(node, storage_map, compute_map, thunk = node.op.make_thunk(node, storage_map, compute_map,
no_recycling=[]) no_recycling=[])
......
...@@ -3671,6 +3671,17 @@ def test_constant_folding(): ...@@ -3671,6 +3671,17 @@ def test_constant_folding():
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 2 assert len(topo) == 2
# Test that we do not crash when constant folding elemwise scalar
# as they should not generate c code.
x = tensor.constant(3)
assert x.ndim == 0
mode = theano.compile.get_mode("FAST_COMPILE").excluding("fusion")
f = theano.function([], [x * 2, x + x], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert all([isinstance(n.op, DeepCopyOp) for n in topo])
def test_constant_get_stabilized(): def test_constant_get_stabilized():
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论