提交 f814c0c7 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

GPU doesn't support -1 reshaping

上级 ecdb5d3f
...@@ -7208,14 +7208,18 @@ def tensordot(a, b, axes = 2): ...@@ -7208,14 +7208,18 @@ def tensordot(a, b, axes = 2):
outshape = concatenate([a.shape[:a.ndim - axes], b.shape[axes:]]) outshape = concatenate([a.shape[:a.ndim - axes], b.shape[axes:]])
outndim = a.ndim + b.ndim - (2 * axes) outndim = a.ndim + b.ndim - (2 * axes)
a_shape_0 = b_shape_0 = 1 a_shape_0 = b_shape_0 = a_shape_1 = b_shape_1 = 1
for s0 in range(a.ndim - axes): for s0 in range(a.ndim - axes):
a_shape_0 *= a.shape[s0] a_shape_0 *= a.shape[s0]
for s0 in range(axes): for s0 in range(axes):
b_shape_0 *= b.shape[s0] b_shape_0 *= b.shape[s0]
for s1 in range(a.ndim - axes, a.ndim):
a_shape_1 *= a.shape[s1]
for s1 in range(axes, b.ndim):
b_shape_1 *= b.shape[s1]
a_reshaped = a.reshape((a_shape_0, -1), ndim = 2) a_reshaped = a.reshape((a_shape_0, a_shape_1), ndim = 2)
b_reshaped = b.reshape((b_shape_0, -1), ndim = 2) b_reshaped = b.reshape((b_shape_0, b_shape_1), ndim = 2)
return _dot(a_reshaped, b_reshaped).reshape(outshape, outndim) return _dot(a_reshaped, b_reshaped).reshape(outshape, outndim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论