提交 072b031f authored 作者: Sina Honari's avatar Sina Honari

correcting the bugs

上级 9b71642c
...@@ -6,6 +6,7 @@ import os ...@@ -6,6 +6,7 @@ import os
import shutil import shutil
import stat import stat
import sys import sys
import warnings
import theano import theano
from theano.compat import get_unbound_function from theano.compat import get_unbound_function
......
...@@ -4656,13 +4656,27 @@ class Flatten(Op): ...@@ -4656,13 +4656,27 @@ class Flatten(Op):
} }
""" % locals() """ % locals()
def is_flat(node, outdim=1): def is_flat(node, outdim=1):
""" """
Checks whether node's op is an instance of Reshape Checks whether node's op is an instance of Reshape
and verifies the dimensionality of variable is correct. and verifies the dimensionality of variable is correct.
""" """
return isinstance(node.op, theano.tensor.Reshape) and\ if not isinstance(node.op, theano.tensor.Reshape):
node.inputs[1].ndim == outdim return False
new_shape = node.inputs[1]
# If new shape is defined by `MakeVector`, then number of inputs to
# `MakeVector` must be equal to `outdim`.
if new_shape.owner and \
isinstance(new_shape.owner.op, theano.tensor.opt.MakeVector):
return new_shape.ndim == 1 and len(new_shape.owner.inputs) == outdim
# `TensorConstant`
elif isinstance(new_shape, theano.tensor.TensorConstant):
return new_shape.ndim == 1 and new_shape.data.shape[0] == outdim
else:
raise NotImplemented()
def flatten(x, outdim=1): def flatten(x, outdim=1):
...@@ -4677,10 +4691,11 @@ def flatten(x, outdim=1): ...@@ -4677,10 +4691,11 @@ def flatten(x, outdim=1):
# even if it's a scalar. Otherwise, outdim must be positive # even if it's a scalar. Otherwise, outdim must be positive
# and smaller than x.ndim. # and smaller than x.ndim.
if outdim < 1 or (outdim > 1 and outdim > x.ndim): if outdim < 1 or (outdim > 1 and outdim > x.ndim):
raise ValueError('outdim %s out of bound [1, %d)' % (outdim, x.ndim+1)) raise ValueError('outdim %s out of bound [1, %d)'
% (outdim, x.ndim + 1))
if outdim > 1: if outdim > 1:
dims = tuple(x.shape[:outdim-1]) + (-1,) dims = tuple(x.shape[:outdim - 1]) + (-1,)
else: else:
dims = (-1,) dims = (-1,)
x_reshaped = x.reshape(dims) x_reshaped = x.reshape(dims)
......
...@@ -5890,7 +5890,7 @@ def test_local_flatten_lift(): ...@@ -5890,7 +5890,7 @@ def test_local_flatten_lift():
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
shape_out_np = tuple(x_np.shape[:i-1])+(numpy.prod(x_np.shape[i-1:]),) shape_out_np = tuple(x_np.shape[:i-1])+(numpy.prod(x_np.shape[i-1:]),)
assert shape_out_np == out_np.shape assert shape_out_np == out_np.shape
assert tensor.is_flat(topo[0], outdim=i) assert tensor.is_flat(topo[-2], outdim=i)
assert isinstance(topo[-1].op, tensor.Elemwise) assert isinstance(topo[-1].op, tensor.Elemwise)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论