Commit 120ad7d7, authored by Pascal Lamblin

Use _dot22 instead of dot, so it is moved to gpu

Parent a0286e23
@@ -2079,8 +2079,17 @@ def tensordot(a, b, axes=2):
                                 tensor.prod(a.shape[a.ndim - axes:])))
         b_reshaped = b.reshape((tensor.prod(b.shape[:axes]),
                                 tensor.prod(b.shape[axes:])))
-        return tensor.dot(a_reshaped, b_reshaped).reshape(outshape,
-                                                          ndim=outndim)
+        assert a_reshaped.ndim == 2
+        assert b_reshaped.ndim == 2
+        # We use _dot22 here because:
+        # - we know that the number of dimensions will be 2
+        # - it makes it possible for the computation to be moved to GPU
+        # When cuda.opt.local_gpu_tensordot is applied, it is too late
+        # for the usual blas optimizations to take place.
+        # This will change if we decide to get rid of tensor.tensordot,
+        # and always use this version.
+        return tensor.blas._dot22(a_reshaped, b_reshaped).reshape(
+            outshape, ndim=outndim)
     elif len(axes) == 2:
         # if 'axes' is a pair of axis lists, we first shuffle the axes of a and
         # b to reduce this to the first case (note the recursion).
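For context, a minimal usage sketch of the code path this commit touches (not part of the commit; variable names and shapes are illustrative, and it assumes a working Theano install). With an integer `axes` argument, `tensordot` reshapes both operands to 2-D, and after this change the resulting matrix product is emitted as `tensor.blas._dot22` rather than `tensor.dot`, so the GPU optimizer can lift it onto the device:

    # Sketch only: shows how the integer-axes branch of tensordot is reached.
    import numpy
    import theano
    import theano.tensor as T

    a = T.tensor3('a')   # symbolic input, shape (i, j, k)
    b = T.matrix('b')    # symbolic input, shape (k, l)

    # axes=1 contracts the last axis of `a` with the first axis of `b`.
    # Internally both inputs are reshaped to 2-D and multiplied; after this
    # commit that product is tensor.blas._dot22 instead of tensor.dot.
    c = T.tensordot(a, b, axes=1)

    f = theano.function([a, b], c)
    out = f(numpy.ones((2, 3, 4), dtype=theano.config.floatX),
            numpy.ones((4, 5), dtype=theano.config.floatX))
    print(out.shape)  # (2, 3, 5)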