提交 e89aee9d authored 作者: notoraptor's avatar notoraptor 提交者: GitHub

Merge pull request #6204 from anirudh9119/anirudh_ccw

using op params for Reshape ops
......@@ -1129,7 +1129,7 @@ class GpuReshape(HideC, tensor.Reshape):
context_name=ctx_name)
return Apply(self, [x, shp], [otype()])
def perform(self, node, inp, out_):
def perform(self, node, inp, out_, params):
x, shp = inp
out, = out_
if (len(shp) != self.ndim):
......@@ -1157,33 +1157,33 @@ class GpuReshape(HideC, tensor.Reshape):
out[0] = x.reshape(tuple(shp))
def c_code_cache_version(self):
return (1,)
return (2,)
def c_code(self, node, name, inputs, outputs, sub):
x, shape = inputs
output, = outputs
new_ndim = self.ndim
sdtype = node.inputs[1].type.dtype_specs()[1]
fail = sub['fail']
params = sub['params']
return """
size_t old_size = 1, new_size = 1;
size_t new_dims[%(new_ndim)s];
size_t new_dims[%(params)s->ndim];
int compute_axis = -1;
assert (PyArray_NDIM(%(shape)s) == 1);
if (PyArray_DIM(%(shape)s, 0) != %(new_ndim)s)
if (PyArray_DIM(%(shape)s, 0) != %(params)s->ndim)
{
PyErr_Format(PyExc_ValueError,
"GpuReshape: given shape is of incorrect "
"length (%%d should be %%d).",
PyArray_DIM(%(shape)s, 0), %(new_ndim)s);
PyArray_DIM(%(shape)s, 0), %(params)s->ndim);
%(fail)s;
}
for (size_t i = 0; i < %(x)s->ga.nd; ++i)
old_size *= %(x)s->ga.dimensions[i];
for (size_t i = 0; i < %(new_ndim)s; ++i)
for (size_t i = 0; i < %(params)s->ndim; ++i)
{
new_dims[i] = ((%(sdtype)s*)(
PyArray_BYTES(%(shape)s) +
......@@ -1224,7 +1224,7 @@ class GpuReshape(HideC, tensor.Reshape):
}
Py_XDECREF(%(output)s);
%(output)s = pygpu_reshape(%(x)s, %(new_ndim)s, new_dims,
%(output)s = pygpu_reshape(%(x)s, %(params)s->ndim, new_dims,
GA_C_ORDER, 0, compute_axis);
if (%(output)s == NULL)
{
......
......@@ -29,6 +29,8 @@ from theano import compile, printing
from theano.printing import pprint, min_informative_str
# For history
from theano.compile import Rebroadcast, Shape, shape
from theano.scalar import int32
# We use these exceptions as well.
import theano.scalar.sharedvar
......@@ -4710,20 +4712,19 @@ def vertical_stack(*args):
class Reshape(Op):
"""Perform a reshape operation of the input x to the new shape shp.
The number of dimensions to which to reshape to (ndim) must be
known at graph build time.
"""
view_map = {0: [0]} # output 0 is potentially aliased to inputs [0]
_f16_ok = True
check_input = False
__props__ = ("ndim",)
params_type = ParamsType(ndim=int32)
# name does not participate because it doesn't affect computations
def __init__(self, ndim, name=None):
self.ndim = ndim
self.ndim = int(ndim)
if ndim < 0:
raise ValueError("The output dimensions after reshape must be 0 or greater")
assert name is None, 'name attribute for Reshape has been deprecated'
......@@ -4763,7 +4764,7 @@ class Reshape(Op):
pass
return gof.Apply(self, [x, shp], [tensor(x.type.dtype, bcasts)])
def perform(self, node, inp, out_):
def perform(self, node, inp, out_, params):
x, shp = inp
out, = out_
if (len(shp) != self.ndim):
......@@ -4860,22 +4861,22 @@ class Reshape(Op):
for i in xrange(self.ndim)])]
def c_code_cache_version(self):
return (7,)
return (8,)
def c_code(self, node, name, inputs, outputs, sub):
if isinstance(node.inputs[0], TensorVariable):
x, shp = inputs
z, = outputs
new_ndim = self.ndim
sdtype = node.inputs[1].type.dtype_specs()[1]
fail = sub['fail']
params = sub['params']
return """
assert (PyArray_NDIM(%(shp)s) == 1);
npy_intp new_dims[%(new_ndim)s];
npy_intp new_dims[%(params)s->ndim];
PyArray_Dims newshape;
newshape.ptr = new_dims;
newshape.len = %(new_ndim)s;
for (int ii = 0; ii < %(new_ndim)s; ++ii)
newshape.len = %(params)s->ndim;
for (int ii = 0; ii < %(params)s->ndim; ++ii)
{
// -- We do not want an explicit cast here. the shp can be any
// -- int* dtype. The compiler will explicitly upcast it, but
......@@ -4886,8 +4887,7 @@ class Reshape(Op):
ii * PyArray_STRIDES(%(shp)s)[0]))[0];
}
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape,
NPY_CORDER);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, NPY_CORDER);
if (!%(z)s)
{
//The error message should have been set by PyArray_Newshape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论