提交 265980e2 authored 作者: notoraptor's avatar notoraptor

Clarify code + small correction (we don't care about first dimension of inputs for BatchedDot).

上级 40dc4e31
......@@ -2151,11 +2151,13 @@ class BatchedDot(Op):
# generate contiguity condition
def contiguous(var, ndim):
strides = "PyArray_STRIDES(%s)" % var
if ndim == 1:
return "{strides}[0] == type_size".format(strides=strides)
return " && ".join([
" && ".join("{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0"
.format(strides=strides, i=i) for i in range(ndim)),
"(%s)" % (" || ".join("{strides}[{i}] == type_size"
.format(strides=strides, i=i) for i in range(1, ndim)) or '1'),
.format(strides=strides, i=i) for i in range(1, ndim)),
"(%s)" % " || ".join("{strides}[{i}] == type_size"
.format(strides=strides, i=i) for i in range(1, ndim)),
])
x_ndim, y_ndim, z_ndim = node.inputs[0].ndim, node.inputs[1].ndim, node.outputs[0].ndim
......@@ -2309,7 +2311,7 @@ class BatchedDot(Op):
def c_code_cache_version(self):
from theano.tensor.blas_headers import blas_header_version
return (2, blas_header_version())
return (3, blas_header_version())
def grad(self, inp, grads):
x, y = inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论