提交 8cd678da authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: small improvements

上级 0b89695c
......@@ -3547,8 +3547,8 @@ class BatchedDot(Op):
allocate = """
if ((NULL == %(_z)s) || !(%(z_shape_correct)s))
{
npy_intp dims[3] = {%(z_shape)s};
if (NULL != %(_z)s) Py_XDECREF(%(_z)s);
npy_intp dims[%(z_ndim)s] = {%(z_shape)s};
Py_XDECREF(%(_z)s);
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(
%(z_ndim)s, dims, PyArray_TYPE(%(_x)s));
if(!%(_z)s) {
......@@ -3586,7 +3586,7 @@ class BatchedDot(Op):
for axis in shape)
return """{
npy_intp dims[3] = {%(_shape)s};
PyArray_Dims newshape = {.ptr = dims, .len = 3};
PyArray_Dims newshape = {dims, 3};
%(newname)s = (PyArrayObject*)PyArray_Newshape(%(oldname)s, &newshape, NPY_ANYORDER);
if (!%(newname)s)
%(_fail)s
......@@ -3610,8 +3610,7 @@ class BatchedDot(Op):
upcast.append("ys = %(_y)s; Py_XINCREF(ys);")
elif y_ndim == 2:
upcast.append(c_dimshuffle("ys", _y, (0, 1, None)))
# upcast of z depends on shapes of both inputs
if x_ndim == 3 and y_ndim == 3:
if z_ndim == 3:
upcast.append("zs = %(_z)s; Py_XINCREF(zs);")
else:
upcast.append(c_dimshuffle(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论