Commit 9b6f8491, authored by Frederic

Remove an error message; optimization: disable merging of GpuAllocEmpty.

Parent: fc9ce753
...@@ -655,6 +655,16 @@ class PureOp(object): ...@@ -655,6 +655,16 @@ class PureOp(object):
""" """
return True return True
def do_merge(self, node):
"""This allow to disable the merge of ops in the graph.
This is very rarely a good idea to disable it. Do not use if
you do not understand this small comment. You probably do not
need it.
"""
return True
class Op(utils.object2, PureOp, CLinkerOp): class Op(utils.object2, PureOp, CLinkerOp):
"""Convenience class to bundle `PureOp` and `CLinkerOp`""" """Convenience class to bundle `PureOp` and `CLinkerOp`"""
......
...@@ -517,6 +517,8 @@ class MergeFeature(object): ...@@ -517,6 +517,8 @@ class MergeFeature(object):
"""Check if a node can be merged, and queue that replacement.""" """Check if a node can be merged, and queue that replacement."""
if node in self.nodes_seen: if node in self.nodes_seen:
return return
if not node.op.do_merge(node):
return
# These asserts ensure that the fgraph has set the clients field # These asserts ensure that the fgraph has set the clients field
# properly. # properly.
......
...@@ -3288,6 +3288,9 @@ class GpuAllocEmpty(GpuOp): ...@@ -3288,6 +3288,9 @@ class GpuAllocEmpty(GpuOp):
# XXX: We could implement and call CudaNdarray.empty(sh) instead. # XXX: We could implement and call CudaNdarray.empty(sh) instead.
out[0] = cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(sh) out[0] = cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(sh)
def do_merge(self, node):
return False
def c_code(self, node, name, inputs, out_, sub): def c_code(self, node, name, inputs, out_, sub):
out, = out_ out, = out_
fail = sub['fail'] fail = sub['fail']
...@@ -3340,6 +3343,9 @@ class GpuAlloc(GpuAllocEmpty): ...@@ -3340,6 +3343,9 @@ class GpuAlloc(GpuAllocEmpty):
""" """
__props__ = ('memset_0',) __props__ = ('memset_0',)
def do_merge(self, node):
return True
def __init__(self, memset_0=False): def __init__(self, memset_0=False):
self.memset_0 = memset_0 self.memset_0 = memset_0
......
...@@ -372,6 +372,26 @@ def test_reshape(): ...@@ -372,6 +372,26 @@ def test_reshape():
pass pass
def test_alloc_empty():
    """Check GpuAllocEmpty's output and that its applies are never merged."""
    # A single gpu_alloc_empty should compile down to exactly one apply node
    # producing a float32 array of the requested shape.
    fn = theano.function([], cuda.basic_ops.gpu_alloc_empty(2, 3))
    assert len(fn.maker.fgraph.apply_nodes) == 1
    result = fn()
    assert result.shape == (2, 3)
    assert result.dtype == 'float32'

    # Two identical gpu_alloc_empty applies must NOT be merged into one:
    # both GpuAllocEmpty nodes have to survive in the compiled graph.
    fn = theano.function([], [cuda.basic_ops.gpu_alloc_empty(2, 3),
                              cuda.basic_ops.gpu_alloc_empty(2, 3)])
    for value in fn():
        assert value.shape == (2, 3)
        assert value.dtype == 'float32'
    empty_nodes = [node for node in fn.maker.fgraph.apply_nodes
                   if isinstance(node.op, cuda.basic_ops.GpuAllocEmpty)]
    assert len(empty_nodes) == 2
def test_elemwise_empty(): def test_elemwise_empty():
# test with 0 element # test with 0 element
a = tcn.shared_constructor(theano._asarray(numpy.random.rand(0, 0), a = tcn.shared_constructor(theano._asarray(numpy.random.rand(0, 0),
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment