提交 12d2c60c authored 作者: James Bergstra's avatar James Bergstra

corrected Flatten view_map and perform

上级 56cb3f81
......@@ -1942,8 +1942,7 @@ class Flatten(Op):
"""Flattens a tensor to `outdim` dimensions by preserving the leading outdim-1 shape
components.
"""
#Could be done as a reshape, but this is more direct.
#TODO: optimize reshape(x, prod(shape(x))) -> flatten(x)
view_map = {0:[0]}
def __init__(self, outdim=1):
self.outdim = int(outdim)
def __eq__(self, other):
......@@ -1958,9 +1957,9 @@ class Flatten(Op):
def perform(self, node, (x,), (out,)):
outdim = self.outdim
if outdim == 1:
out[0] = x.flatten()
out[0] = x.reshape(x.size)
elif outdim == len(x.shape):
out[0] = x.copy()
out[0] = x
else:
newshape = x.shape[:outdim-1] + (numpy.prod(x.shape[outdim-1:]),)
#print 'newshape', newshape, x.shape, x.shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论