提交 f14af145 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make an optimization optionally use the shape_feature

上级 5e501473
......@@ -707,18 +707,14 @@ def local_gpua_careduce(node, context_name):
assert reduce_mask[a] == 0
reduce_mask[a] = 1
shape_of = node.fgraph.shape_feature.shape_of
x_shape = shape_of[x]
new_in_shp = [x_shape[0]]
new_in_shp = [shape_i(x, 0)]
new_mask = [reduce_mask[0]]
for i in xrange(1, x.type.ndim):
if reduce_mask[i] == reduce_mask[i - 1]:
new_in_shp[-1] *= x_shape[i]
new_in_shp[-1] *= shape_i(x, i)
else:
new_mask.append(reduce_mask[i])
new_in_shp.append(x_shape[i])
new_in_shp.append(shape_i(x, i))
new_axis = []
for idx, m in enumerate(new_mask):
if m == 1:
......@@ -740,8 +736,12 @@ def local_gpua_careduce(node, context_name):
greduce(gpu_reshaped_x))
if reduce_reshaped_x.ndim != node.outputs[0].ndim:
out_shp = []
for i in range(x.ndim):
if i not in node.op.axis:
out_shp.append(shape_i(x, i))
unreshaped_reduce = reduce_reshaped_x.reshape(
tensor.stack(shape_of[node.outputs[0]]))
tensor.stack(out_shp))
else:
unreshaped_reduce = reduce_reshaped_x
return [unreshaped_reduce]
......
......@@ -14,6 +14,7 @@ from . import dnn
import theano
from theano import scalar as scal
from theano import config, tensor, gof
from theano.compile.ops import shape_i
import theano.ifelse
import theano.tensor.signal.pool
import theano.tensor.nnet
......@@ -899,18 +900,14 @@ def local_gpu_careduce(node):
# to make them a single dimension, do the reduction, and
# then reshape to get them back.
shape_of = node.fgraph.shape_feature.shape_of
x_shape = shape_of[x]
new_in_shp = [x_shape[0]]
new_in_shp = [shape_i(x, 0)]
new_mask = [reduce_mask[0]]
for i in xrange(1, x.type.ndim):
if reduce_mask[i] == reduce_mask[i - 1]:
new_in_shp[-1] *= x_shape[i]
new_in_shp[-1] *= shape_i(x, i)
else:
new_mask.append(reduce_mask[i])
new_in_shp.append(x_shape[i])
new_in_shp.append(shape_i(x, i))
new_greduce = GpuCAReduce(new_mask, scalar_op)
new_x = x.reshape(tensor.stack(new_in_shp))
......@@ -935,8 +932,11 @@ def local_gpu_careduce(node):
# Restore the expected shape of the output
if rval.ndim != out.ndim:
rval = rval.reshape(
tensor.stack(shape_of[out]))
out_shp = []
for i in range(x.ndim):
if i not in node.op.axis:
out_shp.append(shape_i(x, i))
rval = rval.reshape(tensor.stack(out_shp))
if rval.type == out.type:
return [rval]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论