提交 b68811d1 authored 作者: john salvatier's avatar john salvatier

Merge pull request #4 from lamblin/advinc_rebase3

Fix unit tests
import os, sys import os
import sys
from theano.compat import PY3 from theano.compat import PY3
from theano.gof.compilelock import get_lock, release_lock from theano.gof.compilelock import get_lock, release_lock
...@@ -13,11 +14,13 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')): ...@@ -13,11 +14,13 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')):
def compile_cutils(): def compile_cutils():
"""Do just the compilation of cutils_ext""" """Do just the compilation of cutils_ext"""
types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128',
'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256',
'float16', 'float32', 'float64', 'float80', 'float96', 'float128',
'float256']]
complex_types = ['npy_' + t for t in ['complex32', 'complex64',
types = ['npy_'+t for t in ['int8', 'int16', 'int32', 'int64', 'int128', 'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256', 'float16', 'float32', 'float64', 'float80', 'float96', 'float128', 'float256'] ] 'complex128', 'complex160', 'complex192', 'complex512']]
complex_types = ['npy_'+t for t in ['complex32', 'complex64', 'complex128', 'complex160', 'complex192', 'complex512'] ]
inplace_map_template = """ inplace_map_template = """
#if defined(%(typen)s) #if defined(%(typen)s)
...@@ -40,16 +43,20 @@ def compile_cutils(): ...@@ -40,16 +43,20 @@ def compile_cutils():
((%(type)s*)mit->dataptr)[0].imag = ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag; ((%(type)s*)mit->dataptr)[0].imag = ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag;
""" """
fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(),
fns = ''.join([inplace_map_template % {'type' : t, 'typen' : t.upper(), 'op' : floatadd % {'type' : t} } for t in types] + 'op': floatadd % {'type': t}}
[inplace_map_template % {'type' : t, 'typen' : t.upper(), 'op' : complexadd % {'type' : t} } for t in complex_types]) for t in types] +
[inplace_map_template % {'type': t, 'typen': t.upper(),
'op': complexadd % {'type': t}}
for t in complex_types])
fn_array = ("inplace_map_binop addition_funcs[] = {" + fn_array = ("inplace_map_binop addition_funcs[] = {" +
''.join([""" ''.join(["""
#if defined(%(typen)s) #if defined(%(typen)s)
%(type)s_inplace_add, %(type)s_inplace_add,
#endif #endif
""" % {'type' : t, 'typen' : t.upper()} for t in types+complex_types]) + """ % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"""NULL}; """NULL};
""") """)
...@@ -58,11 +65,10 @@ def compile_cutils(): ...@@ -58,11 +65,10 @@ def compile_cutils():
#if defined(%(typen)s) #if defined(%(typen)s)
%(typen)s, %(typen)s,
#endif #endif
""" % {'type' : t, 'typen' : t.upper()} for t in types+complex_types]) + """ % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"-1000};") "-1000};")
code = (""" code = ("""
#include <Python.h> #include <Python.h>
#include "numpy/arrayobject.h" #include "numpy/arrayobject.h"
...@@ -121,7 +127,6 @@ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inp ...@@ -121,7 +127,6 @@ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inp
PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd); PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd);
if (it == NULL) { if (it == NULL) {
Py_DECREF(arr); Py_DECREF(arr);
return -1; return -1;
} }
...@@ -210,10 +215,9 @@ fail: ...@@ -210,10 +215,9 @@ fail:
{NULL, NULL, 0, NULL} /* Sentinel */ {NULL, NULL, 0, NULL} /* Sentinel */
};""") };""")
if PY3: if PY3:
# This is not the most efficient code, but it is written this way to highlight # This is not the most efficient code, but it is written this way to
# the changes needed to make 2.x code compile under python 3. # highlight the changes needed to make 2.x code compile under python 3.
code = code.replace("<Python.h>", '"numpy/npy_3kcompat.h"', 1) code = code.replace("<Python.h>", '"numpy/npy_3kcompat.h"', 1)
code = code.replace("PyCObject", "NpyCapsule") code = code.replace("PyCObject", "NpyCapsule")
code += """ code += """
...@@ -242,7 +246,6 @@ fail: ...@@ -242,7 +246,6 @@ fail:
} //extern C } //extern C
""" """
import cmodule import cmodule
loc = os.path.join(config.compiledir, 'cutils_ext') loc = os.path.join(config.compiledir, 'cutils_ext')
if not os.path.exists(loc): if not os.path.exists(loc):
......
...@@ -23,7 +23,8 @@ import numpy ...@@ -23,7 +23,8 @@ import numpy
import theano import theano
from theano.compat import PY3 from theano.compat import PY3
from theano import gof from theano import gof
from theano.gof import Op, utils, Variable, Constant, Type, Apply, FunctionGraph from theano.gof import (Op, utils, Variable, Constant, Type, Apply,
FunctionGraph)
from theano.gof.python25 import partial, all, any from theano.gof.python25 import partial, all, any
from theano.configparser import config from theano.configparser import config
...@@ -2680,7 +2681,7 @@ class Composite(ScalarOp): ...@@ -2680,7 +2681,7 @@ class Composite(ScalarOp):
except AttributeError: except AttributeError:
if 0: if 0:
l = [] l = []
for n in fgraph.toposort(): for n in self.fgraph.toposort():
if hasattr(n.op, "name") and n.op.name is not None: if hasattr(n.op, "name") and n.op.name is not None:
v = n.op.name v = n.op.name
if v.startswith("Composite"): if v.startswith("Composite"):
......
...@@ -24,7 +24,7 @@ from theano import compile, printing ...@@ -24,7 +24,7 @@ from theano import compile, printing
from theano.printing import pprint, min_informative_str from theano.printing import pprint, min_informative_str
from theano.tensor.utils import hash_from_ndarray from theano.tensor.utils import hash_from_ndarray
import theano.gof.cutils #needed to import cutils_ext import theano.gof.cutils # needed to import cutils_ext
try: try:
from cutils_ext.cutils_ext import inplace_increment from cutils_ext.cutils_ext import inplace_increment
except ImportError: except ImportError:
...@@ -1752,16 +1752,29 @@ class _tensor_py_operators: ...@@ -1752,16 +1752,29 @@ class _tensor_py_operators:
# standard indexing is used; if it fails with # standard indexing is used; if it fails with
# AdvancedIndexingError, advanced indexing # AdvancedIndexingError, advanced indexing
advanced = False advanced = False
for arg in args: axis = None
for i, arg in enumerate(args):
try: try:
arg == numpy.newaxis or Subtensor.convert(arg) arg == numpy.newaxis or Subtensor.convert(arg)
except AdvancedIndexingError: except AdvancedIndexingError:
advanced = True if advanced:
axis = None
break break
else:
advanced = True
axis = i
if advanced: if advanced:
if (len(args) == 1 and as_tensor_variable(args[0]).ndim <= 1): if (axis is not None
return advanced_subtensor1(self, *args) and numpy.all(a == slice(None) for a in args[:axis])
and numpy.all(a == slice(None) for a in args[axis + 1:])
and isinstance(args[axis], (
numpy.ndarray,
list,
TensorVariable,
TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))):
return self.take(arg, axis)
else: else:
return AdvancedSubtensor()(self, *args) return AdvancedSubtensor()(self, *args)
else: else:
...@@ -4439,7 +4452,8 @@ class Subtensor(Op): ...@@ -4439,7 +4452,8 @@ class Subtensor(Op):
slice_c = None slice_c = None
return slice(slice_a, slice_b, slice_c) return slice(slice_a, slice_b, slice_c)
# There is a bug in numpy that results in isinstance(x, int) returning False for numpy integers. # There is a bug in numpy that results in isinstance(x, int) returning
# False for numpy integers.
# See <http://projects.scipy.org/numpy/ticket/2235>. # See <http://projects.scipy.org/numpy/ticket/2235>.
elif isinstance(entry, (numpy.integer, int)): elif isinstance(entry, (numpy.integer, int)):
return entry return entry
...@@ -7198,19 +7212,21 @@ def as_index_variable(idx): ...@@ -7198,19 +7212,21 @@ def as_index_variable(idx):
raise TypeError('index must be integers') raise TypeError('index must be integers')
return idx return idx
def as_int_none_variable(x): def as_int_none_variable(x):
if x is None: if x is None:
return NoneConst return NoneConst
x = as_tensor_variable(x, ndim = 0) x = as_tensor_variable(x, ndim=0)
if x.type.dtype[:3] not in ('int', 'uin'): if x.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers') raise TypeError('index must be integers')
return x return x
class MakeSlice(Op): class MakeSlice(Op):
def make_node(self, slc): def make_node(self, slc):
return Apply(self, return Apply(self,
map(as_int_none_variable,[slc.start, slc.stop, slc.step]), map(as_int_none_variable,
[slc.start, slc.stop, slc.step]),
[slicetype()]) [slicetype()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
...@@ -7227,7 +7243,7 @@ class MakeSlice(Op): ...@@ -7227,7 +7243,7 @@ class MakeSlice(Op):
return hash(type(self)) return hash(type(self))
def grad(self, inputs, grads): def grad(self, inputs, grads):
return [DiconnectedType()() for i in inputs] return [DisconnectedType()() for i in inputs]
make_slice = MakeSlice() make_slice = MakeSlice()
...@@ -7244,7 +7260,6 @@ class SliceType(gof.Type): ...@@ -7244,7 +7260,6 @@ class SliceType(gof.Type):
return "slice" return "slice"
class NoneTypeT(gof.Type): class NoneTypeT(gof.Type):
def filter(self, x, strict=False, allow_downcast=None): def filter(self, x, strict=False, allow_downcast=None):
...@@ -7257,13 +7272,16 @@ class NoneTypeT(gof.Type): ...@@ -7257,13 +7272,16 @@ class NoneTypeT(gof.Type):
return "None" return "None"
slicetype = SliceType() slicetype = SliceType()
NoneConst = Constant(NoneTypeT(), None, name = 'None') NoneConst = Constant(NoneTypeT(), None, name='None')
def adv_index_broadcastable_pattern(a, idx): def adv_index_broadcastable_pattern(a, idx):
""" """
This function is only used to determine the broardcast pattern for AdvancedSubtensor output variable. This function is only used to determine the broadcast pattern for
AdvancedSubtensor output variable.
For this, we make a fake ndarray and a fake idx and call use ask numpy the output. From this, we find the output broadcast pattern. For this, we make a fake ndarray and a fake idx and call use ask numpy
the output. From this, we find the output broadcast pattern.
""" """
def replace_slice(v): def replace_slice(v):
...@@ -7278,9 +7296,9 @@ def adv_index_broadcastable_pattern(a, idx): ...@@ -7278,9 +7296,9 @@ def adv_index_broadcastable_pattern(a, idx):
if NoneConst.equals(v): if NoneConst.equals(v):
return None return None
if isinstance(v.type, SliceType): if isinstance(v.type, SliceType):
return slice(None,None) return slice(None, None)
return numpy.zeros( (2,)* v.ndim, int) return numpy.zeros((2,) * v.ndim, int)
newidx = tuple(map(replace_slice, idx)) newidx = tuple(map(replace_slice, idx))
...@@ -7289,6 +7307,7 @@ def adv_index_broadcastable_pattern(a, idx): ...@@ -7289,6 +7307,7 @@ def adv_index_broadcastable_pattern(a, idx):
retshape = numpy.empty(fakeshape)[newidx].shape retshape = numpy.empty(fakeshape)[newidx].shape
return tuple([dim == 1 for dim in retshape]) return tuple([dim == 1 for dim in retshape])
class AdvancedSubtensor(Op): class AdvancedSubtensor(Op):
"""Return a subtensor copy, using advanced indexing. """Return a subtensor copy, using advanced indexing.
""" """
...@@ -7309,13 +7328,11 @@ class AdvancedSubtensor(Op): ...@@ -7309,13 +7328,11 @@ class AdvancedSubtensor(Op):
x = as_tensor_variable(x) x = as_tensor_variable(x)
index = tuple(map(as_index_variable, index)) index = tuple(map(as_index_variable, index))
bcast = adv_index_broadcastable_pattern(x, index)
return gof.Apply(self, return gof.Apply(self,
(x,) + index, (x,) + index,
[tensor(dtype = x.type.dtype, [tensor(dtype=x.type.dtype,
broadcastable = adv_index_broadcastable_pattern(x, index) )]) broadcastable=bcast)])
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
...@@ -7392,11 +7409,6 @@ class AdvancedIncSubtensor(Op): ...@@ -7392,11 +7409,6 @@ class AdvancedIncSubtensor(Op):
self.allow_legacy_perform = False self.allow_legacy_perform = False
@classmethod
@property
def increment_available():
return inplace_increment is not None
def __hash__(self): def __hash__(self):
return hash((type(self), self.inplace, self.set_instead_of_inc)) return hash((type(self), self.inplace, self.set_instead_of_inc))
...@@ -7417,7 +7429,7 @@ class AdvancedIncSubtensor(Op): ...@@ -7417,7 +7429,7 @@ class AdvancedIncSubtensor(Op):
op = self op = self
# If we are incrementing, but the increment compiled function is not # If we are incrementing, but the increment compiled function is not
# available, we need to support legacy cases. # available, we need to support legacy cases.
if not self.set_instead_of_inc and not self.increment_available: if not self.set_instead_of_inc and inplace_increment is None:
legacy_conditions = False legacy_conditions = False
if x.ndim == 2 and y.ndim == 1 and len(inputs) == 2: if x.ndim == 2 and y.ndim == 1 and len(inputs) == 2:
ind1 = as_tensor_variable(inputs[0]) ind1 = as_tensor_variable(inputs[0])
......
...@@ -43,7 +43,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -43,7 +43,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc, ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1, dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1,
itensor3, Tile, AdvancedIncSubtensor, switch, Diagonal, Diag, itensor3, Tile, AdvancedIncSubtensor, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values) nonzero, flatnonzero, nonzero_values, inplace_increment)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -3750,7 +3750,7 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -3750,7 +3750,7 @@ class TestAdvancedSubtensor(unittest.TestCase):
a.broadcastable, self.ix2.broadcastable) a.broadcastable, self.ix2.broadcastable)
def test_inc_adv_selection(self): def test_inc_adv_selection(self):
if not AdvancedIncSubtensor.increment_available: if inplace_increment is None:
raise SkipTest("inc_subtensor with advanced indexing not enabled. " raise SkipTest("inc_subtensor with advanced indexing not enabled. "
"Installing NumPy 1.8 or the latest development version " "Installing NumPy 1.8 or the latest development version "
"should make that feature available.") "should make that feature available.")
...@@ -3764,7 +3764,7 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -3764,7 +3764,7 @@ class TestAdvancedSubtensor(unittest.TestCase):
assert numpy.allclose(aval, [.4, .9 * 3, .1 * 3]) assert numpy.allclose(aval, [.4, .9 * 3, .1 * 3])
def test_inc_adv_selection2(self): def test_inc_adv_selection2(self):
if not AdvancedIncSubtensor.increment_available: if inplace_increment is None:
raise SkipTest("inc_subtensor with advanced indexing not enabled. " raise SkipTest("inc_subtensor with advanced indexing not enabled. "
"Installing NumPy 1.8 or the latest development version " "Installing NumPy 1.8 or the latest development version "
"should make that feature available.") "should make that feature available.")
...@@ -3786,7 +3786,7 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -3786,7 +3786,7 @@ class TestAdvancedSubtensor(unittest.TestCase):
[.5, .3 * 2, .15]]), aval [.5, .3 * 2, .15]]), aval
def test_inc_adv_selection_with_broadcasting(self): def test_inc_adv_selection_with_broadcasting(self):
if not AdvancedIncSubtensor.increment_available: if inplace_increment is None:
raise SkipTest("inc_subtensor with advanced indexing not enabled. " raise SkipTest("inc_subtensor with advanced indexing not enabled. "
"Installing NumPy 1.8 or the latest development version " "Installing NumPy 1.8 or the latest development version "
"should make that feature available.") "should make that feature available.")
...@@ -7425,6 +7425,8 @@ class TestTensorInstanceMethods(unittest.TestCase): ...@@ -7425,6 +7425,8 @@ class TestTensorInstanceMethods(unittest.TestCase):
self.assertRaises(TypeError, X.take, [0.0]) self.assertRaises(TypeError, X.take, [0.0])
indices = [[1,0,1], [0,1,1]] indices = [[1,0,1], [0,1,1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1)) assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
# Test equivalent advanced indexing
assert_array_equal(X[:,indices].eval({X: x}), x[:,indices])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论