提交 354d7b57 authored 作者: Frédéric Bastien

Merge pull request #2629 from abergeron/fix_merge_opts

Don't apply alpha_merge and output_merge when the proc node has more than one client.
...@@ -32,14 +32,15 @@ def grab_cpu_scalar(v, nd): ...@@ -32,14 +32,15 @@ def grab_cpu_scalar(v, nd):
return v.dimshuffle(()) return v.dimshuffle(())
def find_node(v, cls): def find_node(v, cls, ignore_clients=False):
# This digs through possibly redundant transfers to find the node # This digs through possibly redundant transfers to find the node
# that has the op class specified. # that has the op class specified.
if v.owner is not None: if v.owner is not None and (ignore_clients or len(v.clients) == 1):
if isinstance(v.owner.op, cls): if isinstance(v.owner.op, cls):
return v.owner return v.owner
elif (isinstance(v.owner.op, GpuFromHost) and elif (isinstance(v.owner.op, GpuFromHost) and
v.owner.inputs[0].owner is not None and v.owner.inputs[0].owner is not None and
(ignore_clients or len(v.owner.inputs[0].clients) == 1) and
isinstance(v.owner.inputs[0].owner.op, HostFromGpu)): isinstance(v.owner.inputs[0].owner.op, HostFromGpu)):
return find_node(v.owner.inputs[0].owner.inputs[0], cls) return find_node(v.owner.inputs[0].owner.inputs[0], cls)
else: else:
...@@ -111,8 +112,8 @@ def output_merge(cls, alpha_in, beta_in, out_in, nd): ...@@ -111,8 +112,8 @@ def output_merge(cls, alpha_in, beta_in, out_in, nd):
# other cases are too complex for now # other cases are too complex for now
return None return None
if W.broadcastable != targ.inputs[out_in].broadcastable: if W.broadcastable != targ.inputs[out_in].broadcastable:
# Would need to explicitly tile the output to fill # May change later to do the broadcast, but it's
# the full shape here. Disable for now. # under discussion.
return None return None
inputs = list(targ.inputs) inputs = list(targ.inputs)
inputs[out_in] = W inputs[out_in] = W
......
...@@ -659,6 +659,55 @@ def test_dnn_conv_alpha_output_merge(): ...@@ -659,6 +659,55 @@ def test_dnn_conv_alpha_output_merge():
utt.assert_allclose(v1, v2) utt.assert_allclose(v1, v2)
def test_dnn_conv_merge_mouts():
    """Check that a convolution with multiple clients is left unmerged.

    The alpha_merge/output_merge optimizations must not fire when the
    convolution node's output feeds more than one client, so exactly one
    GpuDnnConv node must remain in the compiled graph.
    """
    if not cuda.dnn.dnn_available():
        raise SkipTest(cuda.dnn.dnn_available.msg)
    image = T.ftensor4()
    filters = T.ftensor4()
    extra = T.ftensor4()
    conv = dnn.dnn_conv(image, filters)
    scale = numpy.asarray(0.05, dtype='float32')
    # cudnn v1 cannot merge an alpha scale, so build the first output
    # without the multiplication in that case.
    if cuda.dnn.version() == -1:
        first_out = conv + extra
    else:
        first_out = scale * (conv + extra)
    second_out = conv * scale
    fn = theano.function([image, filters, extra], [first_out, second_out],
                         mode=mode_with_gpu)
    # Count the conv nodes left after optimization; merging would have
    # duplicated the convolution (one per client).
    n_convs = sum(1 for node in fn.maker.fgraph.toposort()
                  if isinstance(node.op, dnn.GpuDnnConv))
    assert n_convs == 1
def test_dnn_conv_merge_broad():
    """Check that output_merge is skipped for broadcasted outputs.

    Adding a scalar to the convolution output broadcasts it; the merge
    would need an explicit tile, so the conv's output buffer must remain
    a fresh GpuAllocEmpty allocation.
    """
    if not cuda.dnn.dnn_available():
        raise SkipTest(cuda.dnn.dnn_available.msg)
    image = T.ftensor4()
    filters = T.ftensor4()
    conv = dnn.dnn_conv(image, filters)
    scale = numpy.asarray(0.05, dtype='float32')
    # Adding a scalar broadcasts over the whole conv output.
    broadcast_sum = conv + scale
    # NOTE(review): unlike test_dnn_conv_merge_mouts this compiles
    # without mode=mode_with_gpu — confirm that is intentional.
    fn = theano.function([image, filters], [broadcast_sum])
    gpu_convs = [node for node in fn.maker.fgraph.toposort()
                 if isinstance(node.op, dnn.GpuDnnConv)]
    assert len(gpu_convs) == 1
    conv_node = gpu_convs[0]
    # The output input (index 2) must still be an unmerged allocation.
    assert isinstance(conv_node.inputs[2].owner.op, GpuAllocEmpty)
def test_dnn_conv_grad(): def test_dnn_conv_grad():
if not cuda.dnn.dnn_available() or dnn.version() == -1: if not cuda.dnn.dnn_available() or dnn.version() == -1:
raise SkipTest('alpha != 1.0 not supported in cudnn v1') raise SkipTest('alpha != 1.0 not supported in cudnn v1')
......
...@@ -42,6 +42,7 @@ def find_node(v, cls, ignore_clients=False): ...@@ -42,6 +42,7 @@ def find_node(v, cls, ignore_clients=False):
return v.owner return v.owner
elif (isinstance(v.owner.op, GpuFromHost) and elif (isinstance(v.owner.op, GpuFromHost) and
v.owner.inputs[0].owner is not None and v.owner.inputs[0].owner is not None and
(ignore_clients or len(v.owner.inputs[0].clients) == 1) and
isinstance(v.owner.inputs[0].owner.op, HostFromGpu)): isinstance(v.owner.inputs[0].owner.op, HostFromGpu)):
return find_node(v.owner.inputs[0].owner.inputs[0], cls) return find_node(v.owner.inputs[0].owner.inputs[0], cls)
else: else:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论