提交 08d90db3 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Let boolean indexing work with GpuArrayVariables.

上级 01c653dc
......@@ -365,25 +365,37 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
def test_boolean(self):
numpy_n = np.arange(6, dtype=self.dtype).reshape((2, 3))
n = self.shared(numpy_n)
# indexing with a comparison (should translate to a boolean mask)
assert_array_equal(numpy_n[numpy_n > 2], n[n > 2].eval())
assert_array_equal(numpy_n[[0], numpy_n[0] > 2], n[[0], n[0] > 2].eval())
assert_array_equal(numpy_n[[1], numpy_n[0] > 2], n[[1], n[0] > 2].eval())
# indexing with a mask for some dimensions
mask = np.array([True, False])
assert_array_equal(numpy_n[mask], n[mask].eval())
# indexing with a mask for the second dimension
mask = np.array([True, False, True])
assert_array_equal(numpy_n[0, mask], n[0, mask].eval())
assert_array_equal(numpy_n[:, mask], n[:, mask].eval())
assert_array_equal(numpy_n[:, mask], n[:, self.shared(mask)].eval())
# indexing with a boolean array
mask = np.array([[True, False, True], [False, False, True]])
assert_array_equal(numpy_n[mask], n[mask].eval())
assert_array_equal(numpy_n[mask], n[self.shared(mask)].eval())
# for the GpuArray tests, self.shared will create GpuArray variables,
# but we also need to test the combination of GPU and normal TensorVariables
n_cpu = theano.shared(numpy_n)
assert_array_equal(numpy_n[mask], n[theano.shared(mask)].eval())
assert_array_equal(numpy_n[mask], n[theano.tensor.constant(mask)].eval())
assert_array_equal(numpy_n[mask], n_cpu[self.shared(mask)].eval())
# indexing with ellipsis
numpy_n4 = np.arange(48, dtype=self.dtype).reshape((2, 3, 4, 2))
n4 = self.shared(numpy_n4)
# with ellipsis
assert_array_equal(numpy_n4[numpy_n > 2, ...], n4[n > 2, ...].eval())
assert_array_equal(numpy_n4[numpy_n > 2, ..., 1], n4[n > 2, ..., 1].eval())
......
......@@ -467,8 +467,8 @@ class _tensor_py_operators(object):
# Convert boolean arrays to calls to mask.nonzero()
tmp_args = []
for arg in args:
if (isinstance(arg, (np.ndarray, TensorVariable, TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable)) and
if (isinstance(arg, (np.ndarray, theano.tensor.Variable)) and
hasattr(arg, 'dtype') and hasattr(arg, 'nonzero') and
arg.dtype == 'bool'):
tmp_args += arg.nonzero()
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论