提交 f68ffb17 authored 作者: AndroidCloud's avatar AndroidCloud

using op params for Reshape ops

上级 f39ba80a
......@@ -29,6 +29,7 @@ from theano import compile, printing
from theano.printing import pprint, min_informative_str
# For history
from theano.compile import Rebroadcast, Shape, shape
from theano.scalar import Scalar
# We use these exceptions as well.
import theano.scalar.sharedvar
......@@ -4708,22 +4709,22 @@ def vertical_stack(*args):
return concatenate(args, axis=0)
class Reshape(Op):
"""Perform a reshape operation of the input x to the new shape shp.
The number of dimensions to which to reshape to (ndim) must be
known at graph build time.
"""
view_map = {0: [0]} # output 0 is potentially aliased to inputs [0]
_f16_ok = True
check_input = False
__props__ = ("ndim",)
params_type = ParamsType(ndim=Scalar('int32'))
# name does not participate because it doesn't affect computations
def __init__(self, ndim, name=None):
self.ndim = ndim
self.ndim = int(ndim)
if ndim < 0:
raise ValueError("The output dimensions after reshape must be 0 or greater")
assert name is None, 'name attribute for Reshape has been deprecated'
......@@ -4731,6 +4732,10 @@ class Reshape(Op):
def __str__(self):
return '%s{%s}' % (self.__class__.__name__, self.ndim)
def get_params(self, node):
return int(self.ndim)
def make_node(self, x, shp):
x = as_tensor_variable(x)
shp_orig = shp
......@@ -4763,7 +4768,7 @@ class Reshape(Op):
pass
return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcasts)])
def perform(self, node, inp, out_):
def perform(self, node, inp, out_, params):
x, shp = inp
out, = out_
if (len(shp) != self.ndim):
......@@ -4871,11 +4876,11 @@ class Reshape(Op):
fail = sub['fail']
return """
assert (PyArray_NDIM(%(shp)s) == 1);
npy_intp new_dims[%(new_ndim)s];
npy_intp new_dims[%(params)s->ndim];
PyArray_Dims newshape;
newshape.ptr = new_dims;
newshape.len = %(new_ndim)s;
for (int ii = 0; ii < %(new_ndim)s; ++ii)
newshape.len = %(params)s->ndim;
for (int ii = 0; ii < %(params)s->ndim; ++ii)
{
// -- We do not want an explicit cast here. the shp can be any
// -- int* dtype. The compiler will explicitly upcast it, but
......@@ -4886,8 +4891,7 @@ class Reshape(Op):
ii * PyArray_STRIDES(%(shp)s)[0]))[0];
}
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape,
NPY_CORDER);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, NPY_CORDER);
if (!%(z)s)
{
//The error message should have been set by PyArray_Newshape
......@@ -4919,6 +4923,8 @@ def reshape(x, newshape, ndim=None):
return rval
class Flatten(Op):
"""
Flatten a tensor.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论