提交 7680aec3 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

create dot function to call the new tensordot function for inputs with more than 2 dimensions

上级 1ba75ec3
......@@ -7099,11 +7099,47 @@ class Dot(Op):
def __str__(self):
return "dot"
dot = Dot()
_dot = Dot()
pprint.assign(_dot, printing.OperatorPrinter(printing.special['middle_dot'],
-1, 'left'))
def dot(a, b):
"""
Computes the dot product of two variables. For two matrices, this is
equivalent to matrix multiplication. For two vectors, this is the inner
product. When one variable is a scalar, it is like elementwise
multiplication. For N dimensions, it is a sum product over the last axis
of the first array and the second-to-last axis of the second array:
dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
Note that this dot function will do one of three things, in this sequence:
1. If either a or b is scalar, it returns the elementwise product
without calling the Dot op.
2. If either a or b has more than 2 dimensions, it calls the tensordot
function instead of the Dot op. Tensordot expresses high-dimensional
dot products as matrix multiplication and is faster than using a
high-dimensional Dot op.
3. Otherwise, calls the Dot op on a and b.
:note: matrix-matrix products are sometimes optimized to Dot22 ops
(see tensor.blas)
:note: non matrix-matrix products (including matrix-vector
products) are handled by numpy. Ensure that you have linked numpy
with a fast BLAS.
"""
if a.ndim == 0 or b.ndim == 0:
return a * b
elif a.ndim > 2 or b.ndim > 2:
return tensordot(a, b, [[a.ndim - 1], [b.ndim - 2]])
else:
return _dot(a, b)
pprint.assign(dot, printing.OperatorPrinter(printing.special['middle_dot'],
-1, 'left'))
#########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论