提交 8f174da2 authored 作者: James Bergstra's avatar James Bergstra

fixed flatten to work for scalars

上级 0d302be5
...@@ -1819,7 +1819,7 @@ class Flatten(Op): ...@@ -1819,7 +1819,7 @@ class Flatten(Op):
return hash(type(self))^hash(self.outdim) return hash(type(self))^hash(self.outdim)
def make_node(self, x): def make_node(self, x):
t_x = as_tensor(x) t_x = as_tensor(x)
if self.outdim < 1 or self.outdim > x.ndim: if self.outdim < 1 or (x.ndim and self.outdim > x.ndim):
raise ValueError('invalid output ndimensions(%i) for tensor of rank %i' %(self.outdim, t_x.ndim)) raise ValueError('invalid output ndimensions(%i) for tensor of rank %i' %(self.outdim, t_x.ndim))
return gof.Apply(self, [t_x], [tensor(x.type.dtype, (False,)*self.outdim)]) return gof.Apply(self, [t_x], [tensor(x.type.dtype, (False,)*self.outdim)])
def perform(self, node, (x,), (out,)): def perform(self, node, (x,), (out,)):
......
...@@ -1631,6 +1631,18 @@ def test_flatten_outdimNone(): ...@@ -1631,6 +1631,18 @@ def test_flatten_outdimNone():
tensor.verify_grad(None, Flatten(), [a_val]) tensor.verify_grad(None, Flatten(), [a_val])
def test_flatten_scalar():
a = dscalar()
c = flatten(a)
f = function([a], c, mode='FAST_COMPILE')
a_val = numpy.asarray(3.0, dtype='float64')
c_val = numpy.asarray([3.0], dtype='float64')
assert numpy.all(f(a_val)==c_val)
f = function([a], c, mode='FAST_RUN')
assert numpy.all(f(a_val)==c_val)
#tensor.verify_grad(None, Flatten(), [a_val]) #TODO: fix verify_grd to work on scalars
def test_flatten_outdim1(): def test_flatten_outdim1():
a = dmatrix() a = dmatrix()
c = flatten(a, 1) c = flatten(a, 1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论