提交 4bacd641 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make c_code() defer to perform if broadcast is required.

上级 2813465d
......@@ -734,17 +734,63 @@ class GpuJoin(HideC, Join):
[GpuArrayType(broadcastable=node.outputs[0].broadcastable,
dtype=node.outputs[0].dtype)()])
def _need_broadcast(self, node):
p_broadcastable = node.inputs[1].broadcastable
need_b = [any(t.broadcastable[i] != p_broadcastable[i]
for t in node.inputs[2:])
for i in range(node.inputs[1].ndim)]
try:
# We never need to broadcast on the join axis
axis_v = int(tensor.basic.get_scalar_constant_value(node.inputs[0]))
need_b[axis_v] = False
except tensor.basic.NotScalarConstantError:
pass
return any(need_b)
def perform(self, node, axis_and_tensors, out_):
out, = out_
axis = axis_and_tensors[0]
axis = int(axis_and_tensors[0])
tensors = axis_and_tensors[1:]
out[0] = pygpu.concatenate(tensors, axis=axis).astype(
node.outputs[0].dtype)
if not hasattr(node, '_need_broadcast'):
node._need_broadcast = self._need_broadcast(node)
if node._need_broadcast:
width_sum = 0
template_shape = list(tensors[0].shape)
for t in tensors:
width_sum += t.shape[axis]
tmp_shape = list(t.shape)
tmp_shape[axis] = template_shape[axis]
if tmp_shape != template_shape:
raise ValueError("Shape of input GpuArrays must"
" agree except for the 'axis' dimension")
template_shape[axis] = width_sum
rval = pygpu.zeros(template_shape, dtype=node.outputs[0].dtype)
curpos = 0
def construct_slices(curlen):
slices = [slice(None, None, None) for i in \
range(len(template_shape))]
slices[axis] = slice(curpos, curpos + curlen, None)
return tuple(slices)
for t in tensors:
curlen = t.shape[axis]
rval.__setitem__(construct_slices(curlen), t)
curpos += curlen
out[0] = rval
else:
out[0] = pygpu.concatenate(tensors, axis=axis).astype(
node.outputs[0].dtype)
def c_code_cache_version(self):
return (0,)
return (1,)
def c_code(self, node, name, inputs, out_, sub):
if self._need_broadcast(node):
node._need_broadcast = True
raise MethodNotDefined, 'broadcast not supported'
copy_to_list = []
restype=pygpu.gpuarray.dtype_to_typecode(node.outputs[0].dtype)
for i, inp in enumerate(inputs[1:]):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论