提交 eb09a9a2 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed bug in optimization

上级 a3d9f911
...@@ -588,8 +588,6 @@ def local_gpu_careduce(node): ...@@ -588,8 +588,6 @@ def local_gpu_careduce(node):
# and max does not support all combinations of axes # and max does not support all combinations of axes
if node.op.scalar_op in [scal.add, scal.maximum]: if node.op.scalar_op in [scal.add, scal.maximum]:
x, = node.inputs x, = node.inputs
gpu_x = gpu_from_host(x)
gpu_inputs = [ gpu_x ]
if x.owner and x.owner.op == host_from_gpu: if x.owner and x.owner.op == host_from_gpu:
if node.op.axis is None: if node.op.axis is None:
reduce_mask = [1] * x.type.ndim reduce_mask = [1] * x.type.ndim
...@@ -599,8 +597,8 @@ def local_gpu_careduce(node): ...@@ -599,8 +597,8 @@ def local_gpu_careduce(node):
assert reduce_mask[a] == 0 assert reduce_mask[a] == 0
reduce_mask[a] = 1 reduce_mask[a] = 1
greduce = GpuCAReduce(reduce_mask, scalar_op) greduce = GpuCAReduce(reduce_mask, scalar_op)
if greduce.supports_c_code(gpu_inputs): if greduce.supports_c_code([gpu_from_host(x)]):
rval = host_from_gpu(greduce(gpu_x)) rval = host_from_gpu(greduce(gpu_from_host(x)))
if rval.type == node.outputs[0].type: if rval.type == node.outputs[0].type:
return [rval] return [rval]
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论