提交 d1bef012 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Do not loop over broadcastable dimensions

上级 b14af85f
......@@ -1067,22 +1067,44 @@ def _get_preallocated_maps(node, thunk, prealloc_modes, def_val,
# We assume that the different outputs of a same Op will behave
# independently, and there is no need to test over all combinations
# of outputs (the time taken is prohibitive).
# When all outputs on a certain dimension are broadcastable, the Op
# can assume that the shape is 1 on that dimension, and stride testing
# is less relevant.
max_ndim = 0
out_broadcast_pattern = [True] * max_ndim
for r in node.outputs:
if isinstance(r.type, (TensorType, CudaNdarrayType)):
max_ndim = max(max_ndim, r.ndim)
if max_ndim < r.ndim:
out_broadcast_pattern += [True] * (r.ndim - max_ndim)
max_ndim = r.ndim
assert len(out_broadcast_pattern) == max_ndim
for i, b in enumerate(r.broadcastable):
out_broadcast_pattern[i] = out_broadcast_pattern[i] and b
if 'strided' in prealloc_modes or 'ALL' in prealloc_modes:
# Initial allocation
init_strided = {}
for r in node.outputs:
if isinstance(r.type, (TensorType, CudaNdarrayType)):
# Create a buffer twice as large in every dimension
new_buf = r.type.value_zeros(
[(s * 2) for s in r_vals[r].shape])
# Create a buffer twice as large in every dimension,
# except if broadcastable
buf_shape = []
for s, b in zip(r_vals[r].shape, r.broadcastable):
if b:
buf_shape.append(s)
else:
buf_shape.append(s * 2)
new_buf = r.type.value_zeros(buf_shape)
init_strided[r] = new_buf
for step_signs in itertools_product((-1, 1), repeat=max_ndim):
step_signs_list = []
for b in out_broadcast_pattern:
if b:
step_signs_list.append((1,))
else:
step_signs_list.append((-1, 1))
for step_signs in itertools_product(*step_signs_list):
for step_size in (1, 2):
strided = {}
steps = [s * step_size for s in step_signs]
......@@ -1111,7 +1133,11 @@ def _get_preallocated_maps(node, thunk, prealloc_modes, def_val,
if 'wrong_size' in prealloc_modes or 'ALL' in prealloc_modes:
# For each dimension, try size-1, size, size+1
for dim in xrange(max_ndim):
for dim, b in enumerate(out_broadcast_pattern):
if b:
# The shape has to be 1
continue
shape_diff = [0] * max_ndim
for diff in (-1, 1):
shape_diff[dim] = diff
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论