提交 c7103613 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Do not resize output memory in elemwise.

Also use a saner way to compute the output dimensions.
上级 6d0936fc
...@@ -742,43 +742,45 @@ class Elemwise(Op): ...@@ -742,43 +742,45 @@ class Elemwise(Op):
raise ValueError('\n'.join(msg_chunks)) raise ValueError('\n'.join(msg_chunks))
else: else:
raise ValueError(base_exc_str) raise ValueError(base_exc_str)
# Other mismatches will be caught by the ufunc
# Determine the shape of outputs
out_shape = []
for values in zip(*[input.shape for input in inputs]):
if numpy.prod(values) == 0:
# All non-broadcasted dimensions should be zero
assert max(values) <= 1
out_shape.append(0)
else:
out_shape.append(max(values))
out_shape = tuple(out_shape)
if not self.inplace_pattern: if not self.inplace_pattern:
for output, storage in zip(node.outputs, output_storage): for output, storage in zip(node.outputs, output_storage):
odat = storage[0] odat = storage[0]
shape = [max(values)
for values in zip(*[input.shape for input in inputs])]
if odat is not None: if odat is not None:
if odat.shape != shape: if odat.shape != out_shape:
try: # It is unsafe to try to resize odat,
# reuse storage if we can # we have to allocate output storage.
odat.resize(shape) odat = None
except ValueError:
# odat cannot be resized, we have to allocate one
odat = None
if odat is None: if odat is None:
odat = numpy.ndarray(shape, dtype=output.type.dtype) odat = numpy.ndarray(out_shape, dtype=output.type.dtype)
storage[0] = odat storage[0] = odat
else: else:
for i, (output, storage) in enumerate(zip(node.outputs, for i, (output, storage) in enumerate(
output_storage)): zip(node.outputs, output_storage)):
#i is an output idx #i is an output idx
if i in self.inplace_pattern: if i in self.inplace_pattern:
odat = inputs[self.inplace_pattern[i]] odat = inputs[self.inplace_pattern[i]]
else: else:
odat = storage[0] odat = storage[0]
shape = [max(values)
for values in zip(*[input.shape
for input in inputs])]
if odat is not None: if odat is not None:
if odat.shape != shape: if odat.shape != out_shape:
try: # It is unsafe to try to resize odat,
odat.resize(shape) # we have to allocate output storage.
except ValueError: odat = None
odat = None
if odat is None: if odat is None:
odat = numpy.ndarray(shape, dtype=output.type.dtype) odat = numpy.ndarray(out_shape,
dtype=output.type.dtype)
storage[0] = odat storage[0] = odat
ufunc_args = inputs # + output_storage ufunc_args = inputs # + output_storage
...@@ -834,21 +836,16 @@ class Elemwise(Op): ...@@ -834,21 +836,16 @@ class Elemwise(Op):
# always return an ndarray with dtype object # always return an ndarray with dtype object
variable = numpy.asarray(variable, dtype=nout.dtype) variable = numpy.asarray(variable, dtype=nout.dtype)
if (hasattr(variable, 'shape') # The storage has been resized earlier.
and storage[0].shape != variable.shape): if hasattr(variable, 'shape'):
if numpy.prod(variable.shape) == 0: assert storage[0].shape == variable.shape
# numpy doesn't resize from a shape (1,5) to (0,5)
# This bypasses the inplace...
# But it is important in this case.
storage[0] = variable
continue
storage[0].resize(variable.shape)
if storage[0].shape:
storage[0][:] = variable
else: else:
storage[0].itemset(variable) # If variable has no shape, then it is a scalar.
assert numpy.prod(storage[0].shape) == 1
storage[0][...] = variable
assert str(storage[0].dtype) != 'object' assert str(storage[0].dtype) != 'object'
# the following should be used instead of the previous loop, # the following should be used instead of the previous loop,
# unfortunately it tends to segfault # unfortunately it tends to segfault
# self.ufunc(*(ufunc_args+[s[0] for s in output_storage])) # self.ufunc(*(ufunc_args+[s[0] for s in output_storage]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论