提交 986c6dc6 作者: Frédéric Bastien 提交者: GitHub

Merge pull request #5359 from notoraptor/fix-batched-dot

Fix issue #4476 related to BatchedDot.
......@@ -2151,11 +2151,13 @@ class BatchedDot(Op):
# generate contiguity condition
def contiguous(var, ndim):
    """Build the C boolean expression used as a contiguity check for ``var``.

    Parameters
    ----------
    var : str
        Name of the C array variable; it is spliced into
        ``PyArray_STRIDES(<var>)``.
    ndim : int
        Number of dimensions of that array.

    Returns
    -------
    str
        A C expression. For ``ndim == 1`` it requires the single stride to
        equal ``type_size``. Otherwise it requires, for every axis from 1 to
        ``ndim - 1``, a positive stride that is a multiple of ``type_size``,
        and additionally that at least one of those axes has a stride equal
        to ``type_size``.

    Notes
    -----
    Axis 0 is deliberately left out of the multi-dimensional check (the
    loops start at 1), so its stride is not constrained by this expression.
    ``type_size`` is emitted verbatim; it is presumably a variable declared
    in the surrounding generated C code -- confirm against the caller.
    """
    strides = "PyArray_STRIDES(%s)" % var
    if ndim == 1:
        return "{strides}[0] == type_size".format(strides=strides)

    # Only the non-leading axes take part in the check.
    inner_axes = range(1, ndim)

    # Every inner stride must be positive and aligned on the element size.
    aligned = " && ".join(
        "{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0"
        .format(strides=strides, i=i) for i in inner_axes)

    # At least one inner axis must be densely packed (unit element stride).
    unit_stride = "(%s)" % " || ".join(
        "{strides}[{i}] == type_size"
        .format(strides=strides, i=i) for i in inner_axes)

    return " && ".join([aligned, unit_stride])
x_ndim, y_ndim, z_ndim = node.inputs[0].ndim, node.inputs[1].ndim, node.outputs[0].ndim
......@@ -2309,7 +2311,7 @@ class BatchedDot(Op):
def c_code_cache_version(self):
    """Return the version tuple for this Op's C code.

    Returns
    -------
    tuple
        ``(3, blas_header_version())``. The leading integer is the current
        revision number of this Op's C code (it supersedes the earlier
        value 1); the second element is whatever ``blas_header_version()``
        reports.
    """
    # Imported inside the method, as in the original code.
    from theano.tensor.blas_headers import blas_header_version
    return (3, blas_header_version())
def grad(self, inp, grads):
x, y = inp
......
......@@ -2771,6 +2771,36 @@ def test_batched_dot():
assert result.shape[0] == first_mat_val.shape[0]
def test_batched_dot_not_contiguous():
    """Yield checks of ``batched_dot`` on a strided, non-contiguous input.

    The first operand is a transposed-then-sliced view, so it is neither
    C- nor F-contiguous. Each yielded case compares the compiled function's
    result against a per-item ``numpy.dot`` reference, once with the first
    axis in forward order and once reversed.
    """
    def np_genarray(*_shape):
        # Array of the requested shape filled with 0, 1, 2, ... in floatX.
        total = 1
        for extent in _shape:
            total *= extent
        flat = numpy.arange(total, dtype=floatX)
        return flat.reshape(_shape)

    # Compile the symbolic batched dot once; both cases reuse it.
    X = tensor3()
    W = tensor3()
    f = function([X, W], batched_dot(X, W))

    w = np_genarray(30, 10, 5)
    # Transposing the (20, 40, 30) array gives a (30, 40, 20) view.
    x_container = np_genarray(20, 40, 30).T

    def check_first_dim(inverted):
        # step is -1 to walk the first axis backwards, +1 otherwise.
        if inverted:
            step = -1
        else:
            step = 1
        x = x_container[::step, ::2, ::2]

        itemsize = numpy.dtype(floatX).itemsize
        assert x.shape == (30, 20, 10)
        assert x.strides[0] == step * itemsize
        assert not (x.flags['C_CONTIGUOUS'] or x.flags['F_CONTIGUOUS'])

        result = f(x, w)
        per_item = [numpy.dot(u, v) for u, v in zip(x, w)]
        ref_result = numpy.asarray(per_item)
        utt.assert_allclose(ref_result, result)

    for inverted in (0, 1):
        yield (check_first_dim, inverted)
def test_batched_tensordot():
first = theano.tensor.tensor4("first")
second = theano.tensor.tensor4("second")
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论