提交 fbdcbef5 authored 作者: Sina Honari's avatar Sina Honari

applying Fred's changes

上级 ee70708b
...@@ -4417,7 +4417,7 @@ class Reshape(Op): ...@@ -4417,7 +4417,7 @@ class Reshape(Op):
if ele == -1: if ele == -1:
requ[i] = missing requ[i] = missing
elif crit == 1: # we reshape to -1 elif crit == 1: # we reshape to -1
requ = [mul(*ishapes[0])] if ishapes[0] else [1] requ = [mul(*ishapes[0])] if ishapes[0] else []
elif crit > 1: elif crit > 1:
raise ValueError('shape argument to Reshape.perform' raise ValueError('shape argument to Reshape.perform'
' must have at most one entry equal to -1') ' must have at most one entry equal to -1')
...@@ -4657,9 +4657,9 @@ class Flatten(Op): ...@@ -4657,9 +4657,9 @@ class Flatten(Op):
""" % locals() """ % locals()
def is_flat(node, outdim=1): def is_flat(var, outdim=1):
""" """
Verifies the dimensionality of the node's variable is equal to Verifies the dimensionality of the var is equal to
outdim. This method is usually called after flatten method on a outdim. This method is usually called after flatten method on a
variable, where the first outdim-1 dimension size(s) of the variable variable, where the first outdim-1 dimension size(s) of the variable
is kept intact, and the last dimension size of the variable is made is kept intact, and the last dimension size of the variable is made
...@@ -4668,19 +4668,19 @@ def is_flat(node, outdim=1): ...@@ -4668,19 +4668,19 @@ def is_flat(node, outdim=1):
Parameters Parameters
---------- ----------
node : theano.tensor.var.TensorVariable var : theano.tensor.var.TensorVariable
the theano node on which the dimensionality is checked. the theano var on which the dimensionality is checked.
outdim : int outdim : int
the expected dimensionality of node. the expected dimensionality of var.
Returns Returns
------- -------
bool bool
the comparison result of node's dim the comparison result of var's dim
and the expected outdim. and the expected outdim.
""" """
return node.ndim == outdim return var.ndim == outdim
def flatten(x, outdim=1): def flatten(x, outdim=1):
...@@ -4718,9 +4718,8 @@ def flatten(x, outdim=1): ...@@ -4718,9 +4718,8 @@ def flatten(x, outdim=1):
bcast_kept_dims = x.broadcastable[:outdim - 1] bcast_kept_dims = x.broadcastable[:outdim - 1]
bcast_new_dim = python_all(x.broadcastable[outdim - 1:]) bcast_new_dim = python_all(x.broadcastable[outdim - 1:])
broadcastable = bcast_kept_dims + (bcast_new_dim,) broadcastable = bcast_kept_dims + (bcast_new_dim,)
for dim, br in enumerate(broadcastable): x_reshaped = theano.tensor.addbroadcast(
if br: x_reshaped, *filter(lambda i: broadcastable[i], range(outdim)))
x_reshaped = theano.tensor.addbroadcast(x_reshaped, dim)
return x_reshaped return x_reshaped
......
...@@ -5878,19 +5878,27 @@ def test_local_useless_split(): ...@@ -5878,19 +5878,27 @@ def test_local_useless_split():
def test_local_flatten_lift(): def test_local_flatten_lift():
"""
.. note:: The Flatten(Op) is deprecated, and this method
should be removed with Flatten.
"""
for i in xrange(1, 4): for i in xrange(1, 4):
x = tensor.tensor4() x = tensor.tensor4()
out = tensor.flatten(T.exp(x), i) out = tensor.flatten(T.exp(x), i)
assert out.ndim == i assert out.ndim == i
mode = compile.mode.get_default_mode() mode = compile.mode.get_default_mode()
mode = mode.including('local_flatten_lift') mode = mode.including('local_reshape_lift')
f = theano.function([x], out, mode=mode) f = theano.function([x], out, mode=mode)
x_np = numpy.random.rand(5, 4, 3, 2).astype(config.floatX) x_np = numpy.random.rand(5, 4, 3, 2).astype(config.floatX)
out_np = f(x_np) out_np = f(x_np)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
shape_out_np = tuple(x_np.shape[:i-1])+(numpy.prod(x_np.shape[i-1:]),) shape_out_np = tuple(x_np.shape[:i-1])+(numpy.prod(x_np.shape[i-1:]),)
assert shape_out_np == out_np.shape assert shape_out_np == out_np.shape
assert tensor.is_flat(topo[-2].outputs[0], outdim=i)
reshape_nodes = filter(
lambda apply_node: isinstance(apply_node.op, tensor.Reshape), topo)
assert (len(reshape_nodes) == 1 and
tensor.is_flat(reshape_nodes[0].outputs[0], outdim=i))
assert isinstance(topo[-1].op, tensor.Elemwise) assert isinstance(topo[-1].op, tensor.Elemwise)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论