提交 93540f1d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Better checks for dimensions or shape mismatch (used in tests).

上级 6c4b2fca
......@@ -179,7 +179,9 @@ class RandomFunction(gof.Op):
r, shape, args = inputs[0], inputs[1], inputs[2:]
assert type(r) == numpy.random.RandomState
r_orig = r
assert self.outtype.ndim == len(shape) + self.ndim_added
if self.outtype.ndim != len(shape) + self.ndim_added:
raise ValueError('Shape mismatch: self.outtype.ndim (%i) != len(shape) (%i) + self.ndim_added (%i)'\
%(self.outtype.ndim, len(shape), self.ndim_added))
if not self.inplace:
r = copy(r)
rout[0] = r
......@@ -189,6 +191,9 @@ class RandomFunction(gof.Op):
out[0] = numpy.asarray(rval, dtype = node.outputs[1].type.dtype)
else:
out[0] = rval
if len(rval.shape) != self.outtype.ndim:
raise ValueError('Shape mismatch: "out" should have dimension %i, but the value produced by "perform" has dimension %i'\
% (self.outtype.ndim, len(rval.shape)))
def grad(self, inputs, outputs):
return [None for i in inputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论