提交 efd2d19a authored 作者: Hector Munoz's avatar Hector Munoz 提交者: Ricardo Vieira

Remove Flatten Op

上级 2fee841e
......@@ -44,7 +44,6 @@ from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.shape import (
Shape,
Shape_i,
reshape,
shape,
shape_padaxis,
shape_padleft,
......@@ -2845,163 +2844,6 @@ def vertical_stack(*args):
return concatenate(_args, axis=0)
class Flatten(COp):
"""
Flatten a tensor.
Flattens a tensor to `ndim` dimensions by preserving the leading
ndim - 1 shape components.
.. note:: The interface Flatten(Op) is deprecated, you should use flatten.
"""
view_map = {0: [0]}
check_input = False
__props__ = ("ndim",)
def __init__(self, ndim=1):
warnings.warn(
"Flatten class is deprecated, " "please use flatten method instead.",
DeprecationWarning,
stacklevel=4,
)
self.ndim = int(ndim)
def __str__(self):
return f"{self.__class__.__name__}{{{self.ndim}}}"
def make_node(self, x):
t_x = as_tensor_variable(x)
if self.ndim < 1 or (x.ndim and self.ndim > x.ndim):
raise ValueError(
f"invalid output ndimensions ({self.ndim}) for tensor of "
f"rank {t_x.ndim}"
)
# Infer the broadcastable pattern of the output. For every dimension
# unaffected by the flatten, the broadcast flag should be unchanged.
# For the dimension resulting from the collapse of other dimensions,
# it should be broadcastable iff all the collapsed dimensions were
# broadcastable.
bcast_kept_dims = x.broadcastable[: self.ndim - 1]
bcast_new_dim = builtins.all(x.broadcastable[self.ndim - 1 :])
broadcastable = bcast_kept_dims + (bcast_new_dim,)
return Apply(self, [t_x], [tensor(x.type.dtype, broadcastable)])
def perform(self, node, inp, out_):
(x,) = inp
(out,) = out_
ndim = self.ndim
if ndim == 1:
try:
out[0] = x.reshape(x.size)
except AttributeError:
out[0] = x.reshape((np.prod(x.shape),))
elif ndim == len(x.shape):
out[0] = x
else:
newshape = x.shape[: ndim - 1] + (np.prod(x.shape[ndim - 1 :]),)
out[0] = x.reshape(newshape)
def infer_shape(self, fgraph, node, in_shapes):
from aesara.tensor.math import prod
(in_shp,) = in_shapes
part1 = in_shp[: self.ndim - 1]
part2 = in_shp[self.ndim - 1 :]
if len(part2) > 1:
part2 = (prod(part2, dtype="int64"),)
elif len(part2) == 1:
# We do not want to force an upcast of part2 if its length is 1
pass
else:
if len(in_shp) == 0 and self.ndim == 1:
part2 = (1,)
else:
raise ValueError(
f"invalid output ndimensions ({self.ndim}) for tensor "
f"of rank {len(in_shp)}"
)
out_shape = part1 + part2
return [out_shape]
def grad(self, inp, grads):
(x,) = inp
(g_out,) = grads
return [reshape(g_out, shape(x), x.ndim)]
def R_op(self, inputs, eval_points):
if None in eval_points:
return [None]
return self.make_node(*eval_points).outputs
def c_code_cache_version(self):
return (1, 1)
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
(out,) = outputs
ndim = self.ndim
fail = sub["fail"]
return (
"""
if (%(ndim)s == PyArray_NDIM(%(x)s))
{
Py_XDECREF(%(out)s);
Py_XINCREF(%(x)s);
%(out)s = %(x)s;
}
else
{
Py_XDECREF(%(out)s);
if (%(ndim)s == 1)
{
npy_intp size = PyArray_SIZE(%(x)s);
PyArray_Dims newshape;
newshape.ptr = &size;
newshape.len = 1;
%(out)s = (PyArrayObject*)PyArray_Newshape(%(x)s,
&newshape,
NPY_CORDER);
}
else
{
npy_intp *oldshape = PyArray_DIMS(%(x)s);
npy_intp newshape_dims[%(ndim)s];
int i;
for (i = 0; i < %(ndim)s - 1; ++i)
newshape_dims[i] = oldshape[i];
newshape_dims[i] = 1;
for (int j = %(ndim)s - 1; j < PyArray_NDIM(%(x)s); ++j)
newshape_dims[i] *= oldshape[j];
PyArray_Dims newshape;
newshape.ptr = newshape_dims;
newshape.len = %(ndim)s;
%(out)s = (PyArrayObject*)PyArray_Newshape(%(x)s,
&newshape,
NPY_CORDER);
}
}
if (!%(out)s)
{
//The error message should have been set by
// PyArray_Newshape
%(fail)s;
}
"""
% locals()
)
def is_flat(var, ndim=None, outdim=None):
"""
Verifies the dimensionality of the var is equal to
......
......@@ -45,7 +45,6 @@ from aesara.raise_op import Assert, CheckAndRaise, assert_op
from aesara.tensor.basic import (
Alloc,
AllocEmpty,
Flatten,
Join,
MakeVector,
Rebroadcast,
......@@ -2665,39 +2664,6 @@ def local_useless_split(fgraph, node):
return [out2]
@register_canonicalize
@register_stabilize
@local_optimizer([Flatten])
def local_flatten_lift(fgraph, node):
"""
Flatten(UnaryElemwise(x)) -> UnaryElemwise(Flatten(x))
This optimization is needed by optimization
log1msigm_to_softplus to get applied when there is a flatten.
"""
if (
isinstance(node.op, Flatten)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Elemwise)
and len(node.inputs[0].owner.inputs) == 1
):
f = node.op(node.inputs[0].owner.inputs[0])
# Copy over stacktrace from previous output node (flatten op),
# since this is the op which may cause an error for f.
copy_stack_trace(node.outputs, f)
e = node.inputs[0].owner.op(f)
# Copy over stacktrace from previous output node and from unary
# elementwise output node since if there was an error, it would
# probably have come from that operation.
copy_stack_trace(node.outputs + [node.inputs[0]], e)
return [e]
def local_reshape_chain(op):
@local_optimizer([op])
def f(fgraph, node):
......
......@@ -50,7 +50,6 @@ list of ops that support R-op:
* Join
* Rebroadcast
* Reshape
* Flatten
* DimShuffle
* Scan [In tests/scan/test_basic.test_rop]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论