提交 f68ffb17 authored 作者: AndroidCloud's avatar AndroidCloud

using op params for Reshape ops

上级 f39ba80a
...@@ -29,6 +29,7 @@ from theano import compile, printing ...@@ -29,6 +29,7 @@ from theano import compile, printing
from theano.printing import pprint, min_informative_str from theano.printing import pprint, min_informative_str
# For history # For history
from theano.compile import Rebroadcast, Shape, shape from theano.compile import Rebroadcast, Shape, shape
from theano.scalar import Scalar
# We use these exceptions as well. # We use these exceptions as well.
import theano.scalar.sharedvar import theano.scalar.sharedvar
...@@ -4708,22 +4709,22 @@ def vertical_stack(*args): ...@@ -4708,22 +4709,22 @@ def vertical_stack(*args):
return concatenate(args, axis=0) return concatenate(args, axis=0)
class Reshape(Op): class Reshape(Op):
"""Perform a reshape operation of the input x to the new shape shp. """Perform a reshape operation of the input x to the new shape shp.
The number of dimensions to which to reshape to (ndim) must be The number of dimensions to which to reshape to (ndim) must be
known at graph build time. known at graph build time.
""" """
view_map = {0: [0]} # output 0 is potentially aliased to inputs [0] view_map = {0: [0]} # output 0 is potentially aliased to inputs [0]
_f16_ok = True _f16_ok = True
check_input = False check_input = False
__props__ = ("ndim",) __props__ = ("ndim",)
params_type = ParamsType(ndim=Scalar('int32'))
# name does not participate because it doesn't affect computations # name does not participate because it doesn't affect computations
def __init__(self, ndim, name=None): def __init__(self, ndim, name=None):
self.ndim = ndim self.ndim = int(ndim)
if ndim < 0: if ndim < 0:
raise ValueError("The output dimensions after reshape must be 0 or greater") raise ValueError("The output dimensions after reshape must be 0 or greater")
assert name is None, 'name attribute for Reshape has been deprecated' assert name is None, 'name attribute for Reshape has been deprecated'
...@@ -4731,6 +4732,10 @@ class Reshape(Op): ...@@ -4731,6 +4732,10 @@ class Reshape(Op):
def __str__(self): def __str__(self):
return '%s{%s}' % (self.__class__.__name__, self.ndim) return '%s{%s}' % (self.__class__.__name__, self.ndim)
def get_params(self, node):
return int(self.ndim)
def make_node(self, x, shp): def make_node(self, x, shp):
x = as_tensor_variable(x) x = as_tensor_variable(x)
shp_orig = shp shp_orig = shp
...@@ -4763,7 +4768,7 @@ class Reshape(Op): ...@@ -4763,7 +4768,7 @@ class Reshape(Op):
pass pass
return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcasts)]) return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcasts)])
def perform(self, node, inp, out_): def perform(self, node, inp, out_, params):
x, shp = inp x, shp = inp
out, = out_ out, = out_
if (len(shp) != self.ndim): if (len(shp) != self.ndim):
...@@ -4871,11 +4876,11 @@ class Reshape(Op): ...@@ -4871,11 +4876,11 @@ class Reshape(Op):
fail = sub['fail'] fail = sub['fail']
return """ return """
assert (PyArray_NDIM(%(shp)s) == 1); assert (PyArray_NDIM(%(shp)s) == 1);
npy_intp new_dims[%(new_ndim)s]; npy_intp new_dims[%(params)s->ndim];
PyArray_Dims newshape; PyArray_Dims newshape;
newshape.ptr = new_dims; newshape.ptr = new_dims;
newshape.len = %(new_ndim)s; newshape.len = %(params)s->ndim;
for (int ii = 0; ii < %(new_ndim)s; ++ii) for (int ii = 0; ii < %(params)s->ndim; ++ii)
{ {
// -- We do not want an explicit cast here. the shp can be any // -- We do not want an explicit cast here. the shp can be any
// -- int* dtype. The compiler will explicitly upcast it, but // -- int* dtype. The compiler will explicitly upcast it, but
...@@ -4886,8 +4891,7 @@ class Reshape(Op): ...@@ -4886,8 +4891,7 @@ class Reshape(Op):
ii * PyArray_STRIDES(%(shp)s)[0]))[0]; ii * PyArray_STRIDES(%(shp)s)[0]))[0];
} }
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, %(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, NPY_CORDER);
NPY_CORDER);
if (!%(z)s) if (!%(z)s)
{ {
//The error message should have been set by PyArray_Newshape //The error message should have been set by PyArray_Newshape
...@@ -4919,6 +4923,8 @@ def reshape(x, newshape, ndim=None): ...@@ -4919,6 +4923,8 @@ def reshape(x, newshape, ndim=None):
return rval return rval
class Flatten(Op): class Flatten(Op):
""" """
Flatten a tensor. Flatten a tensor.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论