提交 12d2c60c authored 作者: James Bergstra's avatar James Bergstra

corrected Flatten view_map and perform

上级 56cb3f81
...@@ -1942,8 +1942,7 @@ class Flatten(Op): ...@@ -1942,8 +1942,7 @@ class Flatten(Op):
"""Flattens a tensor to `outdim` dimensions by preserving the leading outdim-1 shape """Flattens a tensor to `outdim` dimensions by preserving the leading outdim-1 shape
components. components.
""" """
#Could be done as a reshape, but this is more direct. view_map = {0:[0]}
#TODO: optimize reshape(x, prod(shape(x))) -> flatten(x)
def __init__(self, outdim=1): def __init__(self, outdim=1):
self.outdim = int(outdim) self.outdim = int(outdim)
def __eq__(self, other): def __eq__(self, other):
...@@ -1958,9 +1957,9 @@ class Flatten(Op): ...@@ -1958,9 +1957,9 @@ class Flatten(Op):
def perform(self, node, (x,), (out,)): def perform(self, node, (x,), (out,)):
outdim = self.outdim outdim = self.outdim
if outdim == 1: if outdim == 1:
out[0] = x.flatten() out[0] = x.reshape(x.size)
elif outdim == len(x.shape): elif outdim == len(x.shape):
out[0] = x.copy() out[0] = x
else: else:
newshape = x.shape[:outdim-1] + (numpy.prod(x.shape[outdim-1:]),) newshape = x.shape[:outdim-1] + (numpy.prod(x.shape[outdim-1:]),)
#print 'newshape', newshape, x.shape, x.shape #print 'newshape', newshape, x.shape, x.shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论