提交 bcbcc1b3 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Don't apply alpha_merge and output_merge when the proc node has more

than one client. Also apply output_merge in the case when the new output has to be broadcasted.
上级 669cf0d1
...@@ -6,6 +6,7 @@ from theano import scalar as scal, Constant ...@@ -6,6 +6,7 @@ from theano import scalar as scal, Constant
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.tensor import (DimShuffle, get_scalar_constant_value, from theano.tensor import (DimShuffle, get_scalar_constant_value,
NotScalarConstantError) NotScalarConstantError)
from theano.tensor.opt import broadcast_like
from theano.sandbox.cuda.basic_ops import ( from theano.sandbox.cuda.basic_ops import (
GpuFromHost, HostFromGpu, host_from_gpu, GpuDimShuffle, GpuElemwise) GpuFromHost, HostFromGpu, host_from_gpu, GpuDimShuffle, GpuElemwise)
...@@ -36,7 +37,8 @@ def find_node(v, cls): ...@@ -36,7 +37,8 @@ def find_node(v, cls):
# This digs through possibly redundant transfers to for the node # This digs through possibly redundant transfers to for the node
# that has the op class specified. # that has the op class specified.
if v.owner is not None: if v.owner is not None:
if isinstance(v.owner.op, cls): if (isinstance(v.owner.op, cls) and
len(v.clients) == 1):
return v.owner return v.owner
elif (isinstance(v.owner.op, GpuFromHost) and elif (isinstance(v.owner.op, GpuFromHost) and
v.owner.inputs[0].owner is not None and v.owner.inputs[0].owner is not None and
...@@ -99,9 +101,7 @@ def output_merge(cls, alpha_in, beta_in, out_in, nd): ...@@ -99,9 +101,7 @@ def output_merge(cls, alpha_in, beta_in, out_in, nd):
# other cases are too complex for now # other cases are too complex for now
return None return None
if W.broadcastable != targ.inputs[out_in].broadcastable: if W.broadcastable != targ.inputs[out_in].broadcastable:
# Would need to explicitly tile the output to fill W = broadcast_like(W, targ.inputs[out_in], node.fgraph)
# the full shape here. Disable for now.
return None
inputs = list(targ.inputs) inputs = list(targ.inputs)
inputs[out_in] = W inputs[out_in] = W
inputs[beta_in] = _one.clone() inputs[beta_in] = _one.clone()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论