Commit e0bdcce8 authored by notoraptor

Revert to the initial contiguity test and make it more precise.

Extend new tests.
Parent 268e56fd
......@@ -2149,8 +2149,14 @@ class BatchedDot(Op):
fail = sub["fail"]
# generate contiguity condition
def contiguous(var):
    # Emit a C boolean expression that is true when the ndarray
    # named by `var` is either C- or Fortran-contiguous.
    template = "(PyArray_IS_C_CONTIGUOUS(%s) || PyArray_IS_F_CONTIGUOUS(%s))"
    return template % (var, var)
def contiguous(var, ndim):
    """Emit a C boolean expression that is true when the ndarray named
    by `var` is "gemm-ready": every stride is positive and an exact
    multiple of the element size, and at least one non-batch dimension
    has unit stride.

    `type_size` must be defined in the surrounding generated C code as
    the array's item size in bytes.  Negative strides are deliberately
    rejected so that such inputs take the copy path instead.
    """
    strides = "PyArray_STRIDES(%s)" % var
    # Every stride must be positive and a whole multiple of the
    # element size (rules out negative / byte-misaligned strides).
    positive = " && ".join(
        "{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0"
        .format(strides=strides, i=i) for i in range(ndim))
    # At least one non-batch dimension must be truly contiguous
    # (stride == item size); trivially satisfied when ndim <= 1.
    unit = "(%s)" % (" || ".join(
        "{strides}[{i}] == type_size".format(strides=strides, i=i)
        for i in range(1, ndim)) or '1')
    # Drop empty conjuncts so ndim == 0 cannot produce a dangling
    # " && " at the start of the expression.
    return " && ".join(part for part in (positive, unit) if part)
x_ndim, y_ndim, z_ndim = node.inputs[0].ndim, node.inputs[1].ndim, node.outputs[0].ndim
......@@ -2165,7 +2171,7 @@ class BatchedDot(Op):
z_shape_correct = " && ".join("PyArray_DIMS(%s)[%i] == %s"
% (_z, i, dim) for i, dim in enumerate(z_dims))
z_shape = ", ".join(z_dims)
z_contiguous = contiguous(_z)
z_contiguous = contiguous(_z, z_ndim)
allocate = """
if (NULL == %(_z)s || !(%(z_shape_correct)s) || !(%(z_contiguous)s))
{
......@@ -2183,8 +2189,8 @@ class BatchedDot(Op):
# code to reallocate inputs contiguously if necessary
contiguate = []
for var in (_x, _y):
_contiguous = contiguous(var)
for var, ndim in [(_x, x_ndim), (_y, y_ndim)]:
_contiguous = contiguous(var, ndim)
contiguate.append("""
if (!(%(_contiguous)s)) {
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s);
......
......@@ -2783,16 +2783,22 @@ def test_batched_dot_not_contiguous():
Z = batched_dot(X, W)
f = function([X, W], Z)
w = np_genarray(30, 10, 5)
reversed_x_container = np_genarray(20, 40, 30)
x_container = reversed_x_container.T
x = x_container[::1, ::2, ::2]
assert x.shape == (30, 20, 10)
assert x.strides[0] == numpy.dtype(floatX).itemsize
assert not (x.flags['C_CONTIGUOUS'] or x.flags['F_CONTIGUOUS'])
w = np_genarray(30, 10, 5)
result = f(x, w)
ref_result = numpy.asarray(list(numpy.dot(u, v) for u, v in zip(x, w)))
utt.assert_allclose(ref_result, result)
def check_first_dim(inverted):
direction = -1 if inverted else 1
x = x_container[::direction, ::2, ::2]
assert x.shape == (30, 20, 10)
assert x.strides[0] == direction * numpy.dtype(floatX).itemsize
assert not (x.flags['C_CONTIGUOUS'] or x.flags['F_CONTIGUOUS'])
result = f(x, w)
ref_result = numpy.asarray(list(numpy.dot(u, v) for u, v in zip(x, w)))
utt.assert_allclose(ref_result, result)
for inverted in (0, 1):
yield (check_first_dim, inverted)
def test_batched_tensordot():
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment