提交 23040281 authored 作者: John Salvatier's avatar John Salvatier

fixed changes nouiz suggested

import os, sys import os
import sys
from theano.compat import PY3 from theano.compat import PY3
from theano.gof.compilelock import get_lock, release_lock from theano.gof.compilelock import get_lock, release_lock
...@@ -14,11 +15,13 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')): ...@@ -14,11 +15,13 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')):
def compile_cutils(): def compile_cutils():
"""Do just the compilation of cutils_ext""" """Do just the compilation of cutils_ext"""
types = ['npy_' + t for t in ['int8', 'int16', 'int32', 'int64', 'int128',
'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256',
'float16', 'float32', 'float64', 'float80', 'float96', 'float128',
'float256']]
complex_types = ['npy_' + t for t in ['complex32', 'complex64',
types = ['npy_'+t for t in ['int8', 'int16', 'int32', 'int64', 'int128', 'int256', 'uint8', 'uint16', 'uint32', 'uint64', 'uint128', 'uint256', 'float16', 'float32', 'float64', 'float80', 'float96', 'float128', 'float256'] ] 'complex128', 'complex160', 'complex192', 'complex512']]
complex_types = ['npy_'+t for t in ['complex32', 'complex64', 'complex128', 'complex160', 'complex192', 'complex512'] ]
inplace_map_template = """ inplace_map_template = """
#if defined(%(typen)s) #if defined(%(typen)s)
...@@ -37,33 +40,36 @@ def compile_cutils(): ...@@ -37,33 +40,36 @@ def compile_cutils():
floatadd = "((%(type)s*)mit->dataptr)[0] = ((%(type)s*)mit->dataptr)[0] + ((%(type)s*)it->dataptr)[0];" floatadd = "((%(type)s*)mit->dataptr)[0] = ((%(type)s*)mit->dataptr)[0] + ((%(type)s*)it->dataptr)[0];"
complexadd = """ complexadd = """
((%(type)s*)mit->dataptr)[0].real = ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real; ((%(type)s*)mit->dataptr)[0].real = ((%(type)s*)mit->dataptr)[0].real + ((%(type)s*)it->dataptr)[0].real;
((%(type)s*)mit->dataptr)[0].imag = ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag; ((%(type)s*)mit->dataptr)[0].imag = ((%(type)s*)mit->dataptr)[0].imag + ((%(type)s*)it->dataptr)[0].imag;
""" """
fns = ''.join([inplace_map_template % {'type': t, 'typen': t.upper(),
'op': floatadd % {'type': t}}
for t in types] +
[inplace_map_template % {'type': t, 'typen': t.upper(),
'op': complexadd % {'type': t}}
for t in complex_types])
fns = ''.join([inplace_map_template % {'type' : t, 'typen' : t.upper(), 'op' : floatadd % {'type' : t} } for t in types] + fn_array = ("inplace_map_binop addition_funcs[] = {" +
[inplace_map_template % {'type' : t, 'typen' : t.upper(), 'op' : complexadd % {'type' : t} } for t in complex_types])
fn_array = ("inplace_map_binop addition_funcs[] = {" +
''.join([""" ''.join(["""
#if defined(%(typen)s) #if defined(%(typen)s)
%(type)s_inplace_add, %(type)s_inplace_add,
#endif #endif
""" % {'type' : t, 'typen' : t.upper()} for t in types+complex_types]) + """ % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"""NULL}; """NULL};
""") """)
type_number_array = ("int type_numbers[] = {" + type_number_array = ("int type_numbers[] = {" +
''.join([""" ''.join(["""
#if defined(%(typen)s) #if defined(%(typen)s)
%(typen)s, %(typen)s,
#endif #endif
""" % {'type' : t, 'typen' : t.upper()} for t in types+complex_types]) + """ % {'type': t, 'typen': t.upper()}
for t in types + complex_types]) +
"-1000};") "-1000};")
code = (""" code = ("""
#include <Python.h> #include <Python.h>
#include "numpy/arrayobject.h" #include "numpy/arrayobject.h"
...@@ -91,7 +97,7 @@ def compile_cutils(): ...@@ -91,7 +97,7 @@ def compile_cutils():
#if NPY_API_VERSION >= 0x00000008 #if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *); typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *);
""" + fns + fn_array + type_number_array + """ + fns + fn_array + type_number_array +
""" """
static int static int
...@@ -111,7 +117,7 @@ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inp ...@@ -111,7 +117,7 @@ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inp
return -1; return -1;
} }
if ((mit->subspace != NULL) && (mit->consec)) { if ((mit->subspace != NULL) && (mit->consec)) {
if (mit->iteraxes[0] > 0) { if (mit->iteraxes[0] > 0) {
PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0); PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0);
if (arr == NULL) { if (arr == NULL) {
return -1; return -1;
...@@ -121,8 +127,7 @@ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inp ...@@ -121,8 +127,7 @@ map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inp
it = (PyArrayIterObject*) it = (PyArrayIterObject*)
PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd); PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd);
if (it == NULL) { if (it == NULL) {
Py_DECREF(arr); Py_DECREF(arr);
return -1; return -1;
} }
...@@ -139,13 +144,13 @@ inplace_increment(PyObject *dummy, PyObject *args) ...@@ -139,13 +144,13 @@ inplace_increment(PyObject *dummy, PyObject *args)
{ {
PyObject *arg_a = NULL, *index=NULL, *inc=NULL; PyObject *arg_a = NULL, *index=NULL, *inc=NULL;
PyArrayObject *a; PyArrayObject *a;
inplace_map_binop add_inplace = NULL; inplace_map_binop add_inplace = NULL;
int type_number = -1; int type_number = -1;
int i =0; int i =0;
PyArrayMapIterObject * mit; PyArrayMapIterObject * mit;
if (!PyArg_ParseTuple(args, "OOO", &arg_a, &index, if (!PyArg_ParseTuple(args, "OOO", &arg_a, &index,
&inc)) { &inc)) {
return NULL; return NULL;
} }
if (!PyArray_Check(arg_a)) { if (!PyArray_Check(arg_a)) {
...@@ -154,29 +159,29 @@ inplace_increment(PyObject *dummy, PyObject *args) ...@@ -154,29 +159,29 @@ inplace_increment(PyObject *dummy, PyObject *args)
} }
a = (PyArrayObject *) arg_a; a = (PyArrayObject *) arg_a;
if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) { if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) {
return NULL; return NULL;
} }
if (PyArray_NDIM(a) == 0) { if (PyArray_NDIM(a) == 0) {
PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed."); PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed.");
return NULL; return NULL;
} }
type_number = PyArray_TYPE(a); type_number = PyArray_TYPE(a);
while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){ while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
if (type_number == type_numbers[i]) { if (type_number == type_numbers[i]) {
add_inplace = addition_funcs[i]; add_inplace = addition_funcs[i];
break; break;
} }
i++ ; i++ ;
} }
if (add_inplace == NULL) { if (add_inplace == NULL) {
PyErr_SetString(PyExc_TypeError, "unsupported type for a"); PyErr_SetString(PyExc_TypeError, "unsupported type for a");
return NULL; return NULL;
} }
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index); mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
...@@ -186,9 +191,9 @@ inplace_increment(PyObject *dummy, PyObject *args) ...@@ -186,9 +191,9 @@ inplace_increment(PyObject *dummy, PyObject *args)
if (map_increment(mit, inc, add_inplace) != 0) { if (map_increment(mit, inc, add_inplace) != 0) {
goto fail; goto fail;
} }
Py_DECREF(mit); Py_DECREF(mit);
Py_INCREF(Py_None); Py_INCREF(Py_None);
return Py_None; return Py_None;
...@@ -204,17 +209,16 @@ fail: ...@@ -204,17 +209,16 @@ fail:
{"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS, {"run_cthunk", run_cthunk, METH_VARARGS|METH_KEYWORDS,
"Run a theano cthunk."}, "Run a theano cthunk."},
#if NPY_API_VERSION >= 0x00000008 #if NPY_API_VERSION >= 0x00000008
{"inplace_increment", inplace_increment, {"inplace_increment", inplace_increment,
METH_VARARGS, METH_VARARGS,
"increments a numpy array inplace at the passed indexes."}, "increments a numpy array inplace at the passed indexes."},
#endif #endif
{NULL, NULL, 0, NULL} /* Sentinel */ {NULL, NULL, 0, NULL} /* Sentinel */
};""") };""")
if PY3: if PY3:
# This is not the most efficient code, but it is written this way to highlight # This is not the most efficient code, but it is written this way to
# the changes needed to make 2.x code compile under python 3. # highlight the changes needed to make 2.x code compile under python 3.
code = code.replace("<Python.h>", '"numpy/npy_3kcompat.h"', 1) code = code.replace("<Python.h>", '"numpy/npy_3kcompat.h"', 1)
code = code.replace("PyCObject", "NpyCapsule") code = code.replace("PyCObject", "NpyCapsule")
code += """ code += """
...@@ -243,7 +247,6 @@ fail: ...@@ -243,7 +247,6 @@ fail:
} //extern C } //extern C
""" """
loc = os.path.join(config.compiledir, 'cutils_ext') loc = os.path.join(config.compiledir, 'cutils_ext')
if not os.path.exists(loc): if not os.path.exists(loc):
os.mkdir(loc) os.mkdir(loc)
......
...@@ -23,7 +23,8 @@ import numpy ...@@ -23,7 +23,8 @@ import numpy
import theano import theano
from theano.compat import PY3 from theano.compat import PY3
from theano import gof from theano import gof
from theano.gof import Op, utils, Variable, Constant, Type, Apply, FunctionGraph from theano.gof import (Op, utils, Variable, Constant, Type, Apply,
FunctionGraph)
from theano.gof.python25 import partial, all, any from theano.gof.python25 import partial, all, any
from theano.configparser import config from theano.configparser import config
...@@ -2680,7 +2681,7 @@ class Composite(ScalarOp): ...@@ -2680,7 +2681,7 @@ class Composite(ScalarOp):
except AttributeError: except AttributeError:
if 0: if 0:
l = [] l = []
for n in fgraph.toposort(): for n in self.fgraph.toposort():
if hasattr(n.op, "name") and n.op.name is not None: if hasattr(n.op, "name") and n.op.name is not None:
v = n.op.name v = n.op.name
if v.startswith("Composite"): if v.startswith("Composite"):
......
...@@ -24,9 +24,9 @@ from theano import compile, printing ...@@ -24,9 +24,9 @@ from theano import compile, printing
from theano.printing import pprint, min_informative_str from theano.printing import pprint, min_informative_str
from theano.tensor.utils import hash_from_ndarray from theano.tensor.utils import hash_from_ndarray
import theano.gof.cutils #needed to import cutils_ext import theano.gof.cutils # needed to import cutils_ext
try: try:
from cutils_ext.cutils_ext import inplace_increment from cutils_ext.cutils_ext import inplace_increment
except ImportError: except ImportError:
inplace_increment = None inplace_increment = None
...@@ -1752,16 +1752,29 @@ class _tensor_py_operators: ...@@ -1752,16 +1752,29 @@ class _tensor_py_operators:
# standard indexing is used; if it fails with # standard indexing is used; if it fails with
# AdvancedIndexingError, advanced indexing # AdvancedIndexingError, advanced indexing
advanced = False advanced = False
for arg in args: axis = None
for i, arg in enumerate(args):
try: try:
arg == numpy.newaxis or Subtensor.convert(arg) arg == numpy.newaxis or Subtensor.convert(arg)
except AdvancedIndexingError: except AdvancedIndexingError:
advanced = True if advanced:
break axis = None
break
else:
advanced = True
axis = i
if advanced: if advanced:
if (len(args) == 1 and as_tensor_variable(args[0]).ndim <= 1): if (axis is not None
return advanced_subtensor1(self, *args) and numpy.all(a == slice(None) for a in args[:axis])
and numpy.all(a == slice(None) for a in args[axis + 1:])
and isinstance(args[axis], (
numpy.ndarray,
list,
TensorVariable,
TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))):
return self.take(arg, axis)
else: else:
return AdvancedSubtensor()(self, *args) return AdvancedSubtensor()(self, *args)
else: else:
...@@ -4439,7 +4452,8 @@ class Subtensor(Op): ...@@ -4439,7 +4452,8 @@ class Subtensor(Op):
slice_c = None slice_c = None
return slice(slice_a, slice_b, slice_c) return slice(slice_a, slice_b, slice_c)
# There is a bug in numpy that results in isinstance(x, int) returning False for numpy integers. # There is a bug in numpy that results in isinstance(x, int) returning
# False for numpy integers.
# See <http://projects.scipy.org/numpy/ticket/2235>. # See <http://projects.scipy.org/numpy/ticket/2235>.
elif isinstance(entry, (numpy.integer, int)): elif isinstance(entry, (numpy.integer, int)):
return entry return entry
...@@ -7198,19 +7212,21 @@ def as_index_variable(idx): ...@@ -7198,19 +7212,21 @@ def as_index_variable(idx):
raise TypeError('index must be integers') raise TypeError('index must be integers')
return idx return idx
def as_int_none_variable(x): def as_int_none_variable(x):
if x is None: if x is None:
return NoneConst return NoneConst
x = as_tensor_variable(x, ndim = 0) x = as_tensor_variable(x, ndim=0)
if x.type.dtype[:3] not in ('int', 'uin'): if x.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers') raise TypeError('index must be integers')
return x return x
class MakeSlice(Op): class MakeSlice(Op):
def make_node(self, slc): def make_node(self, slc):
return Apply(self, return Apply(self,
map(as_int_none_variable,[slc.start, slc.stop, slc.step]), map(as_int_none_variable,
[slc.start, slc.stop, slc.step]),
[slicetype()]) [slicetype()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
...@@ -7218,7 +7234,7 @@ class MakeSlice(Op): ...@@ -7218,7 +7234,7 @@ class MakeSlice(Op):
out[0] = slice(*inp) out[0] = slice(*inp)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -7226,11 +7242,11 @@ class MakeSlice(Op): ...@@ -7226,11 +7242,11 @@ class MakeSlice(Op):
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def grad(self, inputs, grads): def grad(self, inputs, grads):
return [DiconnectedType()() for i in inputs] return [DisconnectedType()() for i in inputs]
make_slice = MakeSlice() make_slice = MakeSlice()
class SliceType(gof.Type): class SliceType(gof.Type):
...@@ -7245,7 +7261,6 @@ class SliceType(gof.Type): ...@@ -7245,7 +7261,6 @@ class SliceType(gof.Type):
slicetype = SliceType() slicetype = SliceType()
class NoneTypeT(gof.Type): class NoneTypeT(gof.Type):
def filter(self, x, strict=False, allow_downcast=None): def filter(self, x, strict=False, allow_downcast=None):
...@@ -7257,13 +7272,16 @@ class NoneTypeT(gof.Type): ...@@ -7257,13 +7272,16 @@ class NoneTypeT(gof.Type):
def __str__(self): def __str__(self):
return "None" return "None"
NoneConst = Constant(NoneTypeT(), None, name = 'None') NoneConst = Constant(NoneTypeT(), None, name='None')
def adv_index_broadcastable_pattern(a, idx): def adv_index_broadcastable_pattern(a, idx):
""" """
This function is only used to determine the broardcast pattern for AdvancedSubtensor output variable. This function is only used to determine the broadcast pattern for
AdvancedSubtensor output variable.
For this, we make a fake ndarray and a fake idx and call use ask numpy the output. From this, we find the output broadcast pattern. For this, we make a fake ndarray and a fake idx and call use ask numpy
the output. From this, we find the output broadcast pattern.
""" """
def replace_slice(v): def replace_slice(v):
...@@ -7274,21 +7292,22 @@ def adv_index_broadcastable_pattern(a, idx): ...@@ -7274,21 +7292,22 @@ def adv_index_broadcastable_pattern(a, idx):
" to be fetched.", v) " to be fetched.", v)
else: else:
v = v.outputs[0] v = v.outputs[0]
if NoneConst.equals(v): if NoneConst.equals(v):
return None return None
if isinstance(v.type, SliceType): if isinstance(v.type, SliceType):
return slice(None,None) return slice(None, None)
return numpy.zeros( (2,)* v.ndim, int) return numpy.zeros((2,) * v.ndim, int)
newidx = tuple(map(replace_slice, idx)) newidx = tuple(map(replace_slice, idx))
#2 - True = 1; 2 - False = 2 #2 - True = 1; 2 - False = 2
fakeshape = [2 - bc for bc in a.broadcastable] fakeshape = [2 - bc for bc in a.broadcastable]
retshape = numpy.empty(fakeshape)[newidx].shape retshape = numpy.empty(fakeshape)[newidx].shape
return tuple([dim == 1 for dim in retshape]) return tuple([dim == 1 for dim in retshape])
class AdvancedSubtensor(Op): class AdvancedSubtensor(Op):
"""Return a subtensor copy, using advanced indexing. """Return a subtensor copy, using advanced indexing.
""" """
...@@ -7309,13 +7328,11 @@ class AdvancedSubtensor(Op): ...@@ -7309,13 +7328,11 @@ class AdvancedSubtensor(Op):
x = as_tensor_variable(x) x = as_tensor_variable(x)
index = tuple(map(as_index_variable, index)) index = tuple(map(as_index_variable, index))
bcast = adv_index_broadcastable_pattern(x, index)
return gof.Apply(self, return gof.Apply(self,
(x,) + index, (x,) + index,
[tensor(dtype = x.type.dtype, [tensor(dtype=x.type.dtype,
broadcastable = adv_index_broadcastable_pattern(x, index) )]) broadcastable=bcast)])
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
...@@ -7392,11 +7409,6 @@ class AdvancedIncSubtensor(Op): ...@@ -7392,11 +7409,6 @@ class AdvancedIncSubtensor(Op):
self.allow_legacy_perform = False self.allow_legacy_perform = False
@classmethod
@property
def increment_available():
return inplace_increment is not None
def __hash__(self): def __hash__(self):
return hash((type(self), self.inplace, self.set_instead_of_inc)) return hash((type(self), self.inplace, self.set_instead_of_inc))
...@@ -7417,7 +7429,7 @@ class AdvancedIncSubtensor(Op): ...@@ -7417,7 +7429,7 @@ class AdvancedIncSubtensor(Op):
op = self op = self
# If we are incrementing, but the increment compiled function is not # If we are incrementing, but the increment compiled function is not
# available, we need to support legacy cases. # available, we need to support legacy cases.
if not self.set_instead_of_inc and not self.increment_available: if not self.set_instead_of_inc and inplace_increment is None:
legacy_conditions = False legacy_conditions = False
if x.ndim == 2 and y.ndim == 1 and len(inputs) == 2: if x.ndim == 2 and y.ndim == 1 and len(inputs) == 2:
ind1 = as_tensor_variable(inputs[0]) ind1 = as_tensor_variable(inputs[0])
......
...@@ -43,7 +43,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -43,7 +43,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc, ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1, dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1,
itensor3, Tile, AdvancedIncSubtensor, switch, Diagonal, Diag, itensor3, Tile, AdvancedIncSubtensor, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values) nonzero, flatnonzero, nonzero_values, inplace_increment)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -3721,9 +3721,7 @@ class TestIncSubtensor1(unittest.TestCase): ...@@ -3721,9 +3721,7 @@ class TestIncSubtensor1(unittest.TestCase):
self.assertRaises(TypeError, self.assertRaises(TypeError,
lambda: inc_subtensor(self.v[self.adv1q], fmatrix())) lambda: inc_subtensor(self.v[self.adv1q], fmatrix()))
def check_increment_available(): inplace_increment_missing = SkipTest("inc_subtensor with advanced indexing not enabled. "
if not AdvancedIncSubtensor.increment_available:
raise SkipTest("inc_subtensor with advanced indexing not enabled. "
"Installing NumPy 1.8 or the latest development version " "Installing NumPy 1.8 or the latest development version "
"should make that feature available.") "should make that feature available.")
...@@ -3755,7 +3753,8 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -3755,7 +3753,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
a.broadcastable, self.ix2.broadcastable) a.broadcastable, self.ix2.broadcastable)
def test_inc_adv_subtensor_w_matrix(self): def test_inc_adv_subtensor_w_matrix(self):
check_increment_available() if inplace_increment is None:
raise inplace_increment_missing
a = inc_subtensor(self.v[self.ix2], self.v[self.ix2]) a = inc_subtensor(self.v[self.ix2], self.v[self.ix2])
...@@ -3766,7 +3765,8 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -3766,7 +3765,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
assert numpy.allclose(aval, [.4, .9 * 3, .1 * 3]) assert numpy.allclose(aval, [.4, .9 * 3, .1 * 3])
def test_inc_adv_subtensor_w_2vec(self): def test_inc_adv_subtensor_w_2vec(self):
check_increment_available() if inplace_increment is None:
raise inplace_increment_missing
subt = self.m[self.ix1, self.ix12] subt = self.m[self.ix1, self.ix12]
a = inc_subtensor(subt, subt) a = inc_subtensor(subt, subt)
...@@ -3786,7 +3786,8 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -3786,7 +3786,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
[.5, .3 * 2, .15]]), aval [.5, .3 * 2, .15]]), aval
def test_inc_adv_subtensor_with_broadcasting(self): def test_inc_adv_subtensor_with_broadcasting(self):
check_increment_available() if inplace_increment is None:
raise inplace_increment_missing
a = inc_subtensor(self.m[self.ix1, self.ix12], 2.1) a = inc_subtensor(self.m[self.ix1, self.ix12], 2.1)
...@@ -3804,7 +3805,8 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -3804,7 +3805,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
[.5, .3 + 2.1, .15]]), aval [.5, .3 + 2.1, .15]]), aval
def test_inc_adv_subtensor_with_index_broadcasting(self): def test_inc_adv_subtensor_with_index_broadcasting(self):
check_increment_available() if inplace_increment is None:
raise inplace_increment_missing
a = inc_subtensor(self.m[self.ix1, self.ix2], 2.1) a = inc_subtensor(self.m[self.ix1, self.ix2], 2.1)
...@@ -7441,6 +7443,8 @@ class TestTensorInstanceMethods(unittest.TestCase): ...@@ -7441,6 +7443,8 @@ class TestTensorInstanceMethods(unittest.TestCase):
self.assertRaises(TypeError, X.take, [0.0]) self.assertRaises(TypeError, X.take, [0.0])
indices = [[1,0,1], [0,1,1]] indices = [[1,0,1], [0,1,1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1)) assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
# Test equivalent advanced indexing
assert_array_equal(X[:,indices].eval({X: x}), x[:,indices])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论