Commit fbdcbef5 authored by Sina Honari

applying Fred's changes

Parent ee70708b
......@@ -4417,7 +4417,7 @@ class Reshape(Op):
if ele == -1:
requ[i] = missing
elif crit == 1: # we reshape to -1
requ = [mul(*ishapes[0])] if ishapes[0] else [1]
requ = [mul(*ishapes[0])] if ishapes[0] else []
elif crit > 1:
raise ValueError('shape argument to Reshape.perform'
' must have at most one entry equal to -1')
......@@ -4657,9 +4657,9 @@ class Flatten(Op):
""" % locals()
def is_flat(var, outdim=1):
    """
    Verifies the dimensionality of the var is equal to
    outdim. This method is usually called after flatten method on a
    variable, where the first outdim-1 dimension size(s) of the variable
    is kept intact, and the last dimension size of the variable is made
    equal to the product of its remaining dimension size(s).

    Parameters
    ----------
    var : theano.tensor.var.TensorVariable
        the theano var on which the dimensionality is checked.

    outdim : int
        the expected dimensionality of var.

    Returns
    -------
    bool
        the comparison result of var's dim
        and the expected outdim.
    """
    # Any object exposing an integer `ndim` attribute works here
    # (TensorVariable in practice); no graph computation is involved.
    return var.ndim == outdim
def flatten(x, outdim=1):
......@@ -4718,9 +4718,8 @@ def flatten(x, outdim=1):
bcast_kept_dims = x.broadcastable[:outdim - 1]
bcast_new_dim = python_all(x.broadcastable[outdim - 1:])
broadcastable = bcast_kept_dims + (bcast_new_dim,)
for dim, br in enumerate(broadcastable):
if br:
x_reshaped = theano.tensor.addbroadcast(x_reshaped, dim)
x_reshaped = theano.tensor.addbroadcast(
x_reshaped, *filter(lambda i: broadcastable[i], range(outdim)))
return x_reshaped
......
......@@ -5878,19 +5878,27 @@ def test_local_useless_split():
def test_local_flatten_lift():
    """
    Check that flatten(exp(x)) is rewritten so the Reshape is lifted
    above the Elemwise, leaving exactly one Reshape whose output is
    flat to the requested outdim.

    .. note:: The Flatten(Op) is deprecated, and this method
       should be removed with Flatten.
    """
    for i in xrange(1, 4):
        x = tensor.tensor4()
        out = tensor.flatten(T.exp(x), i)
        assert out.ndim == i
        mode = compile.mode.get_default_mode()
        # The lift optimization is registered under 'local_reshape_lift'
        # (flatten is implemented via Reshape).
        mode = mode.including('local_reshape_lift')
        f = theano.function([x], out, mode=mode)
        x_np = numpy.random.rand(5, 4, 3, 2).astype(config.floatX)
        out_np = f(x_np)
        topo = f.maker.fgraph.toposort()
        # Expected flattened shape: keep the first i-1 dims, collapse the rest.
        shape_out_np = (tuple(x_np.shape[:i - 1]) +
                        (numpy.prod(x_np.shape[i - 1:]),))
        assert shape_out_np == out_np.shape
        # Use a list comprehension rather than filter(): filter() is lazy on
        # Python 3, so len() on its result would raise TypeError.
        reshape_nodes = [node for node in topo
                         if isinstance(node.op, tensor.Reshape)]
        assert len(reshape_nodes) == 1
        assert tensor.is_flat(reshape_nodes[0].outputs[0], outdim=i)
        # After the lift, the Elemwise (exp) is the last node in the graph.
        assert isinstance(topo[-1].op, tensor.Elemwise)
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment