提交 268e56fd authored 作者: notoraptor

Fix issue #4476 related to BatchedDot.

Add related test in theano/tensor/tests/test_basic.py Change C code cache version.
上级 6315fdfa
...@@ -2149,14 +2149,8 @@ class BatchedDot(Op): ...@@ -2149,14 +2149,8 @@ class BatchedDot(Op):
fail = sub["fail"] fail = sub["fail"]
# generate contiguity condition # generate contiguity condition
def contiguous(var, ndim): def contiguous(var):
strides = "PyArray_STRIDES(%s)" % var return "(PyArray_IS_C_CONTIGUOUS(%s) || PyArray_IS_F_CONTIGUOUS(%s))" % (var, var)
return " && ".join([
" && ".join("{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0"
.format(strides=strides, i=i) for i in range(ndim)),
"(%s)" % " || ".join("{strides}[{i}] == type_size"
.format(strides=strides, i=i) for i in range(ndim)),
])
x_ndim, y_ndim, z_ndim = node.inputs[0].ndim, node.inputs[1].ndim, node.outputs[0].ndim x_ndim, y_ndim, z_ndim = node.inputs[0].ndim, node.inputs[1].ndim, node.outputs[0].ndim
...@@ -2171,7 +2165,7 @@ class BatchedDot(Op): ...@@ -2171,7 +2165,7 @@ class BatchedDot(Op):
z_shape_correct = " && ".join("PyArray_DIMS(%s)[%i] == %s" z_shape_correct = " && ".join("PyArray_DIMS(%s)[%i] == %s"
% (_z, i, dim) for i, dim in enumerate(z_dims)) % (_z, i, dim) for i, dim in enumerate(z_dims))
z_shape = ", ".join(z_dims) z_shape = ", ".join(z_dims)
z_contiguous = contiguous(_z, z_ndim) z_contiguous = contiguous(_z)
allocate = """ allocate = """
if (NULL == %(_z)s || !(%(z_shape_correct)s) || !(%(z_contiguous)s)) if (NULL == %(_z)s || !(%(z_shape_correct)s) || !(%(z_contiguous)s))
{ {
...@@ -2189,8 +2183,8 @@ class BatchedDot(Op): ...@@ -2189,8 +2183,8 @@ class BatchedDot(Op):
# code to reallocate inputs contiguously if necessary # code to reallocate inputs contiguously if necessary
contiguate = [] contiguate = []
for var, ndim in [(_x, x_ndim), (_y, y_ndim)]: for var in (_x, _y):
_contiguous = contiguous(var, ndim) _contiguous = contiguous(var)
contiguate.append(""" contiguate.append("""
if (!(%(_contiguous)s)) { if (!(%(_contiguous)s)) {
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s); PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s);
...@@ -2309,7 +2303,7 @@ class BatchedDot(Op): ...@@ -2309,7 +2303,7 @@ class BatchedDot(Op):
def c_code_cache_version(self):
    # Version tuple was bumped 1 -> 2 in this commit so that compiled C
    # code cached before the contiguity fix (issue #4476) is discarded
    # and regenerated. The BLAS header version is included so the cache
    # also invalidates when the BLAS headers change.
    from theano.tensor.blas_headers import blas_header_version
    return (2, blas_header_version())
def grad(self, inp, grads): def grad(self, inp, grads):
x, y = inp x, y = inp
......
...@@ -2771,6 +2771,30 @@ def test_batched_dot(): ...@@ -2771,6 +2771,30 @@ def test_batched_dot():
assert result.shape[0] == first_mat_val.shape[0] assert result.shape[0] == first_mat_val.shape[0]
def test_batched_dot_not_contiguous():
    """Regression test for issue #4476: BatchedDot on non-contiguous input.

    Builds an input view that is neither C- nor F-contiguous (transpose
    followed by step-2 slicing) and checks that batched_dot matches a
    per-batch numpy.dot reference.
    """
    def np_genarray(*_shape):
        # arange(0..size-1) reshaped to the requested shape; use
        # numpy.prod for the element count instead of a manual loop.
        size = int(numpy.prod(_shape))
        return numpy.arange(size, dtype=floatX).reshape(_shape)

    X = tensor3()
    W = tensor3()
    Z = batched_dot(X, W)
    f = function([X, W], Z)

    # Transposing then slicing every other element along the last two
    # axes yields a view that is non-contiguous in both memory orders.
    reversed_x_container = np_genarray(20, 40, 30)
    x_container = reversed_x_container.T
    x = x_container[::1, ::2, ::2]
    assert x.shape == (30, 20, 10)
    # First-axis stride equals the item size (inherited from the
    # transpose), which is what previously tripped the C code.
    assert x.strides[0] == numpy.dtype(floatX).itemsize
    assert not (x.flags['C_CONTIGUOUS'] or x.flags['F_CONTIGUOUS'])

    w = np_genarray(30, 10, 5)
    result = f(x, w)
    # Reference result: one matrix product per batch entry.
    ref_result = numpy.asarray([numpy.dot(u, v) for u, v in zip(x, w)])
    utt.assert_allclose(ref_result, result)
def test_batched_tensordot(): def test_batched_tensordot():
first = theano.tensor.tensor4("first") first = theano.tensor.tensor4("first")
second = theano.tensor.tensor4("second") second = theano.tensor.tensor4("second")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论