提交 71ca5da7 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Integration of the 'None' shape in RandomFunction

上级 fe11af02
......@@ -133,7 +133,7 @@ class RandomFunction(gof.Op):
"""
if shape == () or shape == []:
shape = tensor.lvector()
shape = tensor.as_tensor_variable(shape, dtype='int64')
else:
shape = tensor.as_tensor_variable(shape, ndim=1)
assert shape.type.ndim == 1
......@@ -161,22 +161,48 @@ class RandomFunction(gof.Op):
r, shape, args = inputs[0], inputs[1], inputs[2:]
assert type(r) == numpy.random.RandomState
r_orig = r
if self.outtype.ndim != len(shape) + self.ndim_added:
# If shape == [], that means numpy will compute the correct shape,
# numpy uses shape "None" to represent that. Else, numpy expects a tuple.
# TODO: compute the appropriate shape?
if len(shape) == 0:
shape = None
else:
shape = tuple(shape)
if shape is not None and self.outtype.ndim != len(shape) + self.ndim_added:
raise ValueError('Shape mismatch: self.outtype.ndim (%i) != len(shape) (%i) + self.ndim_added (%i)'\
%(self.outtype.ndim, len(shape), self.ndim_added))
if not self.inplace:
r = copy(r)
rout[0] = r
rval = self.fn(r, *(args + [tuple(shape)]))
rval = self.fn(r, *(args + [shape]))
if not isinstance(rval, numpy.ndarray) \
or str(rval.dtype) != node.outputs[1].type.dtype:
out[0] = theano._asarray(rval, dtype = node.outputs[1].type.dtype)
else:
out[0] = rval
rval = theano._asarray(rval, dtype = node.outputs[1].type.dtype)
# When shape is None, numpy has a tendency to unexpectedly
# return a scalar instead of a higher-dimension array containing
# only one element. This value should be reshaped
if shape is None and rval.ndim == 0 and self.outtype.ndim > 0:
rval = rval.reshape([1]*self.outtype.ndim)
if len(rval.shape) != self.outtype.ndim:
raise ValueError('Shape mismatch: "out" should have dimension %i, but the value produced by "perform" has dimension %i'\
% (self.outtype.ndim, len(rval.shape)))
# Check the output has the right shape
if shape is not None:
if self.ndim_added == 0 and shape != rval.shape:
raise ValueError('Shape mismatch: "out" should have shape %s, but the value produced by "perform" has shape %s'\
% (shape, rval.shape))
elif self.ndim_added > 0 and shape != rval.shape[:-self.ndim_added]:
raise ValueError('Shape mismatch: "out" should have shape starting with %s (plus %i extra dimensions), but the value produced by "perform" has shape %s'\
% (shape, self.ndim_added, rval.shape))
out[0] = rval
def grad(self, inputs, outputs):
return [None for i in inputs]
......@@ -314,6 +340,9 @@ def permutation_helper(random_state, n, shape):
# is a long, the numpy permutation function will crash on Windows.
n = int(n.item())
if shape is None:
# Draw only one permutation, equivalent to shape = ()
shape = ()
out_shape = list(shape)
out_shape.append(n)
out = numpy.zeros(out_shape, int)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论