提交 23040281 · 作者: John Salvatier

fixed changes nouiz suggested

import os, sys
import os
import sys
from theano.compat import PY3
from theano.gof.compilelock import get_lock, release_lock
......@@ -14,11 +15,13 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')):
def compile_cutils():
"""Do just the compilation of cutils_ext"""
types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128',
'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256',
'float16', 'float32', 'float64', 'float80', 'float96', 'float128',
'float256']]
types = ['npy_'+t for t in ['int8', 'int16', 'int32', 'int64', 'int128', 'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256', 'float16', 'float32', 'float64', 'float80', 'float96', 'float128', 'float256'] ]
complex_types = ['npy_'+t for t in ['complex32', 'complex64', 'complex128', 'complex160', 'complex192', 'complex512'] ]
complex_types = ['npy_' + t for t in ['complex32', 'complex64',
'complex128', 'complex160', 'complex192', 'complex512']]
inplace_map_template = """
#if defined(%(typen)s)
......@@ -37,33 +40,36 @@ def compile_cutils():
floatadd = "((%(type)s*)mit->dataptr)[0] = ((%(type)s*)mit->dataptr)[0] + ((%(type)s*)it->dataptr)[0];"
complexadd = """
((%(type)s*)mit->dataptr)[0].real = ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real;
((%(type)s*)mit->dataptr)[0].imag = ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag;
((%(type)s*)mit->dataptr)[0].real = ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real;
((%(type)s*)mit->dataptr)[0].imag = ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag;
"""
fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(),
'op': floatadd % {'type': t}}
for t in types] +
[inplace_map_template % {'type': t, 'typen': t.upper(),
'op': complexadd % {'type': t}}
for t in complex_types])
fns = ''.join([inplace_map_template % {'type' : t, 'typen' : t.upper(), 'op' : floatadd % {'type' : t} } for t in types] +
[inplace_map_template % {'type' : t, 'typen' : t.upper(), 'op' : complexadd % {'type' : t} } for t in complex_types])
fn_array = ("inplace_map_binop addition_funcs[] = {" +
fn_array = ("inplace_map_binop addition_funcs[] = {" +
''.join(["""
#if defined(%(typen)s)
%(type)s_inplace_add,
#endif
""" % {'type' : t, 'typen' : t.upper()} for t in types+complex_types]) +
""" % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"""NULL};
""")
type_number_array = ("int type_numbers[] = {" +
type_number_array = ("int type_numbers[] = {" +
''.join(["""
#if defined(%(typen)s)
%(typen)s,
#endif
""" % {'type' : t, 'typen' : t.upper()} for t in types+complex_types]) +
""" % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"-1000};")
code = ("""
#include <Python.h>
#include "numpy/arrayobject.h"
......@@ -91,7 +97,7 @@ def compile_cutils():
#if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *);
""" + fns + fn_array + type_number_array +
""" + fns + fn_array + type_number_array +
"""
static int
......@@ -111,7 +117,7 @@ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inp
return -1;
}
if ((mit->subspace != NULL) && (mit->consec)) {
if (mit->iteraxes[0] > 0) {
if (mit->iteraxes[0] > 0) {
PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0);
if (arr == NULL) {
return -1;
......@@ -121,8 +127,7 @@ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inp
it = (PyArrayIterObject*)
PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd);
if (it == NULL) {
Py_DECREF(arr);
Py_DECREF(arr);
return -1;
}
......@@ -139,13 +144,13 @@ inplace_increment(PyObject *dummy, PyObject *args)
{
PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
PyArrayObject *a;
inplace_map_binop add_inplace = NULL;
inplace_map_binop add_inplace = NULL;
int type_number = -1;
int i =0;
PyArrayMapIterObject * mit;
if (!PyArg_ParseTuple(args, "OOO", &arg_a, &index,
&inc)) {
&inc)) {
return NULL;
}
if (!PyArray_Check(arg_a)) {
......@@ -154,29 +159,29 @@ inplace_increment(PyObject *dummy, PyObject *args)
}
a = (PyArrayObject *) arg_a;
if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) {
return NULL;
}
}
if (PyArray_NDIM(a) == 0) {
PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed.");
return NULL;
return NULL;
}
type_number = PyArray_TYPE(a);
type_number = PyArray_TYPE(a);
while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
if (type_number == type_numbers[i]) {
add_inplace = addition_funcs[i];
break;
}
i++ ;
}
if (add_inplace == NULL) {
PyErr_SetString(PyExc_TypeError, "unsupported type for a");
PyErr_SetString(PyExc_TypeError, "unsupported type for a");
return NULL;
}
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
......@@ -186,9 +191,9 @@ inplace_increment(PyObject *dummy, PyObject *args)
if (map_increment(mit, inc, add_inplace) != 0) {
goto fail;
}
Py_DECREF(mit);
Py_INCREF(Py_None);
return Py_None;
......@@ -204,17 +209,16 @@ fail:
{"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS,
"Run a theano cthunk."},
#if NPY_API_VERSION >= 0x00000008
{"inplace_increment", inplace_increment,
{"inplace_increment", inplace_increment,
METH_VARARGS,
"increments a numpy array inplace at the passed indexes."},
#endif
{NULL, NULL, 0, NULL} /* Sentinel */
};""")
if PY3:
# This is not the most efficient code, but it is written this way to highlight
# the changes needed to make 2.x code compile under python 3.
# This is not the most efficient code, but it is written this way to
# highlight the changes needed to make 2.x code compile under python 3.
code = code.replace("<Python.h>", '"numpy/npy_3kcompat.h"', 1)
code = code.replace("PyCObject", "NpyCapsule")
code += """
......@@ -243,7 +247,6 @@ fail:
} //extern C
"""
loc = os.path.join(config.compiledir, 'cutils_ext')
if not os.path.exists(loc):
os.mkdir(loc)
......
......@@ -23,7 +23,8 @@ import numpy
import theano
from theano.compat import PY3
from theano import gof
from theano.gof import Op, utils, Variable, Constant, Type, Apply, FunctionGraph
from theano.gof import (Op, utils, Variable, Constant, Type, Apply,
FunctionGraph)
from theano.gof.python25 import partial, all, any
from theano.configparser import config
......@@ -2680,7 +2681,7 @@ class Composite(ScalarOp):
except AttributeError:
if 0:
l = []
for n in fgraph.toposort():
for n in self.fgraph.toposort():
if hasattr(n.op, "name") and n.op.name is not None:
v = n.op.name
if v.startswith("Composite"):
......
......@@ -24,9 +24,9 @@ from theano import compile, printing
from theano.printing import pprint, min_informative_str
from theano.tensor.utils import hash_from_ndarray
import theano.gof.cutils #needed to import cutils_ext
import theano.gof.cutils # needed to import cutils_ext
try:
from cutils_ext.cutils_ext import inplace_increment
from cutils_ext.cutils_ext import inplace_increment
except ImportError:
inplace_increment = None
......@@ -1752,16 +1752,29 @@ class _tensor_py_operators:
# standard indexing is used; if it fails with
# AdvancedIndexingError, advanced indexing
advanced = False
for arg in args:
axis = None
for i, arg in enumerate(args):
try:
arg == numpy.newaxis or Subtensor.convert(arg)
except AdvancedIndexingError:
advanced = True
break
if advanced:
axis = None
break
else:
advanced = True
axis = i
if advanced:
if (len(args) == 1 and as_tensor_variable(args[0]).ndim <= 1):
return advanced_subtensor1(self, *args)
if (axis is not None
and numpy.all(a == slice(None) for a in args[:axis])
and numpy.all(a == slice(None) for a in args[axis + 1:])
and isinstance(args[axis], (
numpy.ndarray,
list,
TensorVariable,
TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))):
return self.take(arg, axis)
else:
return AdvancedSubtensor()(self, *args)
else:
......@@ -4439,7 +4452,8 @@ class Subtensor(Op):
slice_c = None
return slice(slice_a, slice_b, slice_c)
# There is a bug in numpy that results in isinstance(x, int) returning False for numpy integers.
# There is a bug in numpy that results in isinstance(x, int) returning
# False for numpy integers.
# See <http://projects.scipy.org/numpy/ticket/2235>.
elif isinstance(entry, (numpy.integer, int)):
return entry
......@@ -7198,19 +7212,21 @@ def as_index_variable(idx):
raise TypeError('index must be integers')
return idx
def as_int_none_variable(x):
if x is None:
return NoneConst
x = as_tensor_variable(x, ndim = 0)
x = as_tensor_variable(x, ndim=0)
if x.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers')
return x
class MakeSlice(Op):
def make_node(self, slc):
return Apply(self,
map(as_int_none_variable,[slc.start, slc.stop, slc.step]),
map(as_int_none_variable,
[slc.start, slc.stop, slc.step]),
[slicetype()])
def perform(self, node, inp, out_):
......@@ -7218,7 +7234,7 @@ class MakeSlice(Op):
out[0] = slice(*inp)
def __str__(self):
return self.__class__.__name__
return self.__class__.__name__
def __eq__(self, other):
return type(self) == type(other)
......@@ -7226,11 +7242,11 @@ class MakeSlice(Op):
def __hash__(self):
return hash(type(self))
def grad(self, inputs, grads):
return [DiconnectedType()() for i in inputs]
def grad(self, inputs, grads):
return [DisconnectedType()() for i in inputs]
make_slice = MakeSlice()
class SliceType(gof.Type):
......@@ -7245,7 +7261,6 @@ class SliceType(gof.Type):
slicetype = SliceType()
class NoneTypeT(gof.Type):
def filter(self, x, strict=False, allow_downcast=None):
......@@ -7257,13 +7272,16 @@ class NoneTypeT(gof.Type):
def __str__(self):
return "None"
NoneConst = Constant(NoneTypeT(), None, name = 'None')
NoneConst = Constant(NoneTypeT(), None, name='None')
def adv_index_broadcastable_pattern(a, idx):
"""
This function is only used to determine the broardcast pattern for AdvancedSubtensor output variable.
This function is only used to determine the broadcast pattern for
AdvancedSubtensor output variable.
For this, we make a fake ndarray and a fake idx and call use ask numpy the output. From this, we find the output broadcast pattern.
For this, we make a fake ndarray and a fake idx and call use ask numpy
the output. From this, we find the output broadcast pattern.
"""
def replace_slice(v):
......@@ -7274,21 +7292,22 @@ def adv_index_broadcastable_pattern(a, idx):
" to be fetched.", v)
else:
v = v.outputs[0]
if NoneConst.equals(v):
return None
if isinstance(v.type, SliceType):
return slice(None,None)
return numpy.zeros( (2,)* v.ndim, int)
if isinstance(v.type, SliceType):
return slice(None, None)
return numpy.zeros((2,) * v.ndim, int)
newidx = tuple(map(replace_slice, idx))
#2 - True = 1; 2 - False = 2
fakeshape = [2 - bc for bc in a.broadcastable]
fakeshape = [2 - bc for bc in a.broadcastable]
retshape = numpy.empty(fakeshape)[newidx].shape
return tuple([dim == 1 for dim in retshape])
class AdvancedSubtensor(Op):
"""Return a subtensor copy, using advanced indexing.
"""
......@@ -7309,13 +7328,11 @@ class AdvancedSubtensor(Op):
x = as_tensor_variable(x)
index = tuple(map(as_index_variable, index))
bcast = adv_index_broadcastable_pattern(x, index)
return gof.Apply(self,
(x,) + index,
[tensor(dtype = x.type.dtype,
broadcastable = adv_index_broadcastable_pattern(x, index) )])
(x,) + index,
[tensor(dtype=x.type.dtype,
broadcastable=bcast)])
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
......@@ -7392,11 +7409,6 @@ class AdvancedIncSubtensor(Op):
self.allow_legacy_perform = False
@classmethod
@property
def increment_available():
return inplace_increment is not None
def __hash__(self):
return hash((type(self), self.inplace, self.set_instead_of_inc))
......@@ -7417,7 +7429,7 @@ class AdvancedIncSubtensor(Op):
op = self
# If we are incrementing, but the increment compiled function is not
# available, we need to support legacy cases.
if not self.set_instead_of_inc and not self.increment_available:
if not self.set_instead_of_inc and inplace_increment is None:
legacy_conditions = False
if x.ndim == 2 and y.ndim == 1 and len(inputs) == 2:
ind1 = as_tensor_variable(inputs[0])
......
......@@ -43,7 +43,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1,
itensor3, Tile, AdvancedIncSubtensor, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values)
nonzero, flatnonzero, nonzero_values, inplace_increment)
from theano.tests import unittest_tools as utt
......@@ -3721,9 +3721,7 @@ class TestIncSubtensor1(unittest.TestCase):
self.assertRaises(TypeError,
lambda: inc_subtensor(self.v[self.adv1q], fmatrix()))
def check_increment_available():
if not AdvancedIncSubtensor.increment_available:
raise SkipTest("inc_subtensor with advanced indexing not enabled. "
inplace_increment_missing = SkipTest("inc_subtensor with advanced indexing not enabled. "
"Installing NumPy 1.8 or the latest development version "
"should make that feature available.")
......@@ -3755,7 +3753,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
a.broadcastable, self.ix2.broadcastable)
def test_inc_adv_subtensor_w_matrix(self):
check_increment_available()
if inplace_increment is None:
raise inplace_increment_missing
a = inc_subtensor(self.v[self.ix2], self.v[self.ix2])
......@@ -3766,7 +3765,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
assert numpy.allclose(aval, [.4, .9 * 3, .1 * 3])
def test_inc_adv_subtensor_w_2vec(self):
check_increment_available()
if inplace_increment is None:
raise inplace_increment_missing
subt = self.m[self.ix1, self.ix12]
a = inc_subtensor(subt, subt)
......@@ -3786,7 +3786,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
[.5, .3 * 2, .15]]), aval
def test_inc_adv_subtensor_with_broadcasting(self):
check_increment_available()
if inplace_increment is None:
raise inplace_increment_missing
a = inc_subtensor(self.m[self.ix1, self.ix12], 2.1)
......@@ -3804,7 +3805,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
[.5, .3 + 2.1, .15]]), aval
def test_inc_adv_subtensor_with_index_broadcasting(self):
check_increment_available()
if inplace_increment is None:
raise inplace_increment_missing
a = inc_subtensor(self.m[self.ix1, self.ix2], 2.1)
......@@ -7441,6 +7443,8 @@ class TestTensorInstanceMethods(unittest.TestCase):
self.assertRaises(TypeError, X.take, [0.0])
indices = [[1,0,1], [0,1,1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
# Test equivalent advanced indexing
assert_array_equal(X[:,indices].eval({X: x}), x[:,indices])
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论