提交 93540f1d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Better checks for dimensions or shape mismatch (used in tests).

上级 6c4b2fca
...@@ -179,7 +179,9 @@ class RandomFunction(gof.Op): ...@@ -179,7 +179,9 @@ class RandomFunction(gof.Op):
r, shape, args = inputs[0], inputs[1], inputs[2:] r, shape, args = inputs[0], inputs[1], inputs[2:]
assert type(r) == numpy.random.RandomState assert type(r) == numpy.random.RandomState
r_orig = r r_orig = r
assert self.outtype.ndim == len(shape) + self.ndim_added if self.outtype.ndim != len(shape) + self.ndim_added:
raise ValueError('Shape mismatch: self.outtype.ndim (%i) != len(shape) (%i) + self.ndim_added (%i)'\
%(self.outtype.ndim, len(shape), self.ndim_added))
if not self.inplace: if not self.inplace:
r = copy(r) r = copy(r)
rout[0] = r rout[0] = r
...@@ -189,6 +191,9 @@ class RandomFunction(gof.Op): ...@@ -189,6 +191,9 @@ class RandomFunction(gof.Op):
out[0] = numpy.asarray(rval, dtype = node.outputs[1].type.dtype) out[0] = numpy.asarray(rval, dtype = node.outputs[1].type.dtype)
else: else:
out[0] = rval out[0] = rval
if len(rval.shape) != self.outtype.ndim:
raise ValueError('Shape mismatch: "out" should have dimension %i, but the value produced by "perform" has dimension %i'\
% (self.outtype.ndim, len(rval.shape)))
def grad(self, inputs, outputs): def grad(self, inputs, outputs):
return [None for i in inputs] return [None for i in inputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论