提交 072b031f authored 作者: Sina Honari's avatar Sina Honari

correcting the bugs

上级 9b71642c
......@@ -6,6 +6,7 @@ import os
import shutil
import stat
import sys
import warnings
import theano
from theano.compat import get_unbound_function
......
......@@ -4656,13 +4656,27 @@ class Flatten(Op):
}
""" % locals()
def is_flat(node, outdim=1):
"""
Checks whether node's op is an instance of Reshape
and verifies the dimensionality of variable is correct.
"""
return isinstance(node.op, theano.tensor.Reshape) and\
node.inputs[1].ndim == outdim
if not isinstance(node.op, theano.tensor.Reshape):
return False
new_shape = node.inputs[1]
# If new shape is defined by `MakeVector`, then number of inputs to
# `MakeVector` must be equal to `outdim`.
if new_shape.owner and \
isinstance(new_shape.owner.op, theano.tensor.opt.MakeVector):
return new_shape.ndim == 1 and len(new_shape.owner.inputs) == outdim
# `TensorConstant`
elif isinstance(new_shape, theano.tensor.TensorConstant):
return new_shape.ndim == 1 and new_shape.data.shape[0] == outdim
else:
raise NotImplemented()
def flatten(x, outdim=1):
......@@ -4677,10 +4691,11 @@ def flatten(x, outdim=1):
# even if it's a scalar. Otherwise, outdim must be positive
# and smaller than x.ndim.
if outdim < 1 or (outdim > 1 and outdim > x.ndim):
raise ValueError('outdim %s out of bound [1, %d)' % (outdim, x.ndim+1))
raise ValueError('outdim %s out of bound [1, %d)'
% (outdim, x.ndim + 1))
if outdim > 1:
dims = tuple(x.shape[:outdim-1]) + (-1,)
dims = tuple(x.shape[:outdim - 1]) + (-1,)
else:
dims = (-1,)
x_reshaped = x.reshape(dims)
......
......@@ -5890,7 +5890,7 @@ def test_local_flatten_lift():
topo = f.maker.fgraph.toposort()
shape_out_np = tuple(x_np.shape[:i-1])+(numpy.prod(x_np.shape[i-1:]),)
assert shape_out_np == out_np.shape
assert tensor.is_flat(topo[0], outdim=i)
assert tensor.is_flat(topo[-2], outdim=i)
assert isinstance(topo[-1].op, tensor.Elemwise)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论