提交 28fa7b76 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove unused inplace option in DimShuffle

上级 ff5df65f
...@@ -79,12 +79,7 @@ def jax_funcify_DimShuffle(op, **kwargs): ...@@ -79,12 +79,7 @@ def jax_funcify_DimShuffle(op, **kwargs):
for augm in op.augment: for augm in op.augment:
shape.insert(augm, 1) shape.insert(augm, 1)
res = jnp.reshape(res, shape) return jnp.reshape(res, shape)
if not op.inplace:
res = jnp.copy(res)
return res
return dimshuffle return dimshuffle
......
...@@ -414,7 +414,6 @@ def numba_funcify_DimShuffle(op, node, **kwargs): ...@@ -414,7 +414,6 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
shuffle = tuple(op.shuffle) shuffle = tuple(op.shuffle)
transposition = tuple(op.transposition) transposition = tuple(op.transposition)
augment = tuple(op.augment) augment = tuple(op.augment)
inplace = op.inplace
ndim_new_shape = len(shuffle) + len(augment) ndim_new_shape = len(shuffle) + len(augment)
...@@ -474,12 +473,7 @@ def numba_funcify_DimShuffle(op, node, **kwargs): ...@@ -474,12 +473,7 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
new_shape = find_shape(shuffle_shape) new_shape = find_shape(shuffle_shape)
# FIXME: Numba's `array.reshape` only accepts C arrays. # FIXME: Numba's `array.reshape` only accepts C arrays.
res_reshape = np.reshape(np.ascontiguousarray(x), new_shape) return np.reshape(np.ascontiguousarray(x), new_shape)
if not inplace:
return res_reshape.copy()
else:
return res_reshape
else: else:
......
...@@ -61,12 +61,7 @@ def pytorch_funcify_DimShuffle(op, **kwargs): ...@@ -61,12 +61,7 @@ def pytorch_funcify_DimShuffle(op, **kwargs):
for augm in op.augment: for augm in op.augment:
shape.insert(augm, 1) shape.insert(augm, 1)
res = torch.reshape(res, shape) return torch.reshape(res, shape)
if not op.inplace:
res = res.clone()
return res
return dimshuffle return dimshuffle
......
...@@ -7,10 +7,6 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA ...@@ -7,10 +7,6 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
npy_intp* dimensions; npy_intp* dimensions;
npy_intp* strides; npy_intp* strides;
// This points to either the original input or a copy we create below.
// Either way, this is what we should be working on/with.
PyArrayObject *_input;
if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) { if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) {
PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous."); PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous.");
return 1; return 1;
...@@ -20,7 +16,7 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA ...@@ -20,7 +16,7 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
nd_out = PyArray_SIZE(params->_new_order); nd_out = PyArray_SIZE(params->_new_order);
if (PyArray_NDIM(input) != nd_in) { if (PyArray_NDIM(input) != nd_in) {
PyErr_SetString(PyExc_NotImplementedError, "DimShuffle: Input has less dimensions than expected."); PyErr_SetString(PyExc_ValueError, "DimShuffle: Input has less dimensions than expected.");
return 1; return 1;
} }
...@@ -34,12 +30,12 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA ...@@ -34,12 +30,12 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
return 1; return 1;
}; };
npy_intp original_size = PyArray_SIZE(_input); npy_intp original_size = PyArray_SIZE(input);
npy_intp new_size = 1; npy_intp new_size = 1;
for (npy_intp i = 0; i < nd_out; ++i) { for (npy_intp i = 0; i < nd_out; ++i) {
if (new_order[i] != -1) { if (new_order[i] != -1) {
dimensions[i] = PyArray_DIMS(_input)[new_order[i]]; dimensions[i] = PyArray_DIMS(input)[new_order[i]];
strides[i] = PyArray_DIMS(_input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(_input)[new_order[i]]; strides[i] = PyArray_DIMS(input)[new_order[i]] == 1 ? 0 : PyArray_STRIDES(input)[new_order[i]];
} else { } else {
dimensions[i] = 1; dimensions[i] = 1;
strides[i] = 0; strides[i] = 0;
...@@ -57,22 +53,13 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA ...@@ -57,22 +53,13 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
if (*res) if (*res)
Py_XDECREF(*res); Py_XDECREF(*res);
if (params->inplace) {
_input = input;
Py_INCREF((PyObject*)_input);
} else {
_input = (PyArrayObject *)PyArray_FromAny(
(PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY,
NULL);
}
// Create the new array. // Create the new array.
*res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions, *res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions,
PyArray_TYPE(_input), strides, PyArray_TYPE(input), strides,
PyArray_DATA(_input), PyArray_ITEMSIZE(_input), PyArray_DATA(input), PyArray_ITEMSIZE(input),
// borrow only the writable flag from the base // borrow only the writable flag from the base
// the NPY_OWNDATA flag will default to 0. // the NPY_OWNDATA flag will default to 0.
(NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(_input)), (NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(input)),
NULL); NULL);
if (*res == NULL) { if (*res == NULL) {
...@@ -81,12 +68,13 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA ...@@ -81,12 +68,13 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
return 1; return 1;
} }
// Declare it a view of the original input
Py_INCREF((PyObject*)input);
PyArray_SetBaseObject(*res, (PyObject*)input);
// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED // recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL); PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL);
// we are making a view in both inplace and non-inplace cases
PyArray_SetBaseObject(*res, (PyObject*)_input);
free(strides); free(strides);
free(dimensions); free(dimensions);
return 0; return 0;
......
...@@ -19,7 +19,6 @@ from pytensor.misc.frozendict import frozendict ...@@ -19,7 +19,6 @@ from pytensor.misc.frozendict import frozendict
from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.printing import Printer, pprint from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity from pytensor.scalar.basic import identity as scalar_identity
from pytensor.scalar.basic import int64, transfer_type, upcast from pytensor.scalar.basic import int64, transfer_type, upcast
from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import elemwise_cgen as cgen
...@@ -114,15 +113,15 @@ class DimShuffle(ExternalCOp): ...@@ -114,15 +113,15 @@ class DimShuffle(ExternalCOp):
_f16_ok = True _f16_ok = True
check_input = False check_input = False
__props__ = ("input_ndim", "new_order", "inplace") __props__ = ("input_ndim", "new_order")
c_func_file = "c_code/dimshuffle.c" c_func_file = "c_code/dimshuffle.c"
c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)" c_func_name = "APPLY_SPECIFIC(cpu_dimshuffle)"
view_map = {0: [0]}
@property @property
def params_type(self): def params_type(self):
return ParamsType( return ParamsType(
_new_order=lvector, _new_order=lvector,
inplace=scalar_bool,
input_ndim=int64, input_ndim=int64,
) )
...@@ -135,7 +134,6 @@ class DimShuffle(ExternalCOp): ...@@ -135,7 +134,6 @@ class DimShuffle(ExternalCOp):
self.input_ndim = input_ndim self.input_ndim = input_ndim
self.new_order = tuple(new_order) self.new_order = tuple(new_order)
self._new_order = [(-1 if x == "x" else x) for x in self.new_order] self._new_order = [(-1 if x == "x" else x) for x in self.new_order]
self.inplace = True
for i, j in enumerate(new_order): for i, j in enumerate(new_order):
if j != "x": if j != "x":
...@@ -178,9 +176,6 @@ class DimShuffle(ExternalCOp): ...@@ -178,9 +176,6 @@ class DimShuffle(ExternalCOp):
:input_ndim :input_ndim
] == list(range(input_ndim)) ] == list(range(input_ndim))
if self.inplace:
self.view_map = {0: [0]}
def __setstate__(self, state): def __setstate__(self, state):
self.__dict__.update(state) self.__dict__.update(state)
if not hasattr(self, "func_files"): if not hasattr(self, "func_files"):
...@@ -248,12 +243,7 @@ class DimShuffle(ExternalCOp): ...@@ -248,12 +243,7 @@ class DimShuffle(ExternalCOp):
new_shape = list(res.shape[: len(self.shuffle)]) new_shape = list(res.shape[: len(self.shuffle)])
for augm in self.augment: for augm in self.augment:
new_shape.insert(augm, 1) new_shape.insert(augm, 1)
res = res.reshape(new_shape) out[0][0] = res.reshape(new_shape)
if not self.inplace:
res = np.copy(res)
out[0][0] = res
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
(ishp,) = shapes (ishp,) = shapes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论