提交 91fe0a2b authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: write docstrings

上级 1a285716
......@@ -3374,22 +3374,9 @@ def transpose(x, axes=None):
class BatchedDot(Op):
"""
Dot but Batched
WRITEME
Computes the dot product of two variables. For two matrices, this is
equivalent to matrix multiplication. For two vectors, this is the inner
product.
Notes
-----
Matrix-matrix products are sometimes optimized to Dot22 or Gemm ops
(see tensor.blas).
Vector-vector products are sometimes optimized to Ger or CGer (see
tensor.blas).
Matrix-vector products are sometimes optimized to Gemv, CGemv (see
tensor.blas).
Computes the batched dot product of two variables:
batched_dot(a, b)[i] = dot(a[i], b[i])
"""
__props__ = ()
......@@ -3497,7 +3484,7 @@ class BatchedDot(Op):
iv1 = gof.op.get_test_value(inputs[1])
except AttributeError:
gof.op.missing_test_message(
'second input passed to Dot.R_op has no test value')
'second input passed to BatchedDot.R_op has no test value')
debugger_available = False
if eval_points[0]:
......@@ -3505,7 +3492,7 @@ class BatchedDot(Op):
ev0 = gof.op.get_test_value(eval_points[0])
except AttributeError:
gof.op.missing_test_message(
'first eval point passed to Dot.R_op '
'first eval point passed to BatchedDot.R_op '
'has no test value')
debugger_available = False
if eval_points[1]:
......@@ -3513,7 +3500,7 @@ class BatchedDot(Op):
ev1 = gof.op.get_test_value(eval_points[1])
except AttributeError:
gof.op.missing_test_message(
'second eval point passed to Dot.R_op '
'second eval point passed to BatchedDot.R_op '
'has no test value')
debugger_available = False
......@@ -3526,7 +3513,7 @@ class BatchedDot(Op):
input_values[i].shape != eval_point_values[i].shape:
raise ValueError(
'input ' + str(i) + ' and eval_point ' + str(i) +
' to Dot.R_op should have the same shape, but '
' to BatchedDot.R_op should have the same shape, but '
'their shapes are %s and %s, respectively' % (
str(input_values[i].shape),
str(eval_point_values[i].shape)))
......@@ -3555,41 +3542,24 @@ class BatchedDot(Op):
def batched_dot(a, b):
"""
WRITEME
Computes the dot product of two variables.
For two matrices, this is equivalent to matrix multiplication.
For two vectors, this is the inner product.
When one variable is a scalar, this is like elementwise multiplication.
For N dimensions, this is a sum product over the last axis
of the first array and the second-to-last axis of the second array:
dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
Compute the batched dot product of two variables:
Note that this dot function does one of three things, in the following
sequence:
batched_dot(a, b)[i] = dot(a[i], b[i])
1. If either a or b is scalar, it returns the elementwise product
without calling the Theano Dot op.
Note that this batched_dot function does one of three things, in the
following sequence:
2. If either a or b has more than 2 dimensions, it calls Theano's
tensordot function with appropriate axes. The tensordot function
expresses high-dimensional dot products in terms of 2D matrix
multiplications, so it may be possible to futherize optimize for
performance.
1. If either a or b is a vector, it returns the batched elementwise
product without calling the Theano BatchedDot op.
3. If both a and b have either 1 or 2 dimensions, it calls Theano's
Dot op on a and b.
Notes
-----
Matrix-matrix products are sometimes optimized to Dot22 or Gemm ops
(see tensor.blas).
Vector-vector products are sometimes optimized to Ger or CGer (see
tensor.blas).
Matrix-vector products are sometimes optimized to Gemv, CGemv (see
tensor.blas).
2. If both a and b have either 2 or 3 dimensions, it calls Theano's
BatchedDot op on a and b.
3. If either a or b has more than 3 dimensions, it calls Theano's
batched_tensordot function with appropriate axes. The
batched_tensordot function expresses high-dimensional batched
dot products in terms of batched matrix-matrix dot products, so
it may be possible to futherize optimize for performance.
"""
a, b = as_tensor_variable(a), as_tensor_variable(b)
......@@ -3606,11 +3576,11 @@ def batched_dot(a, b):
def batched_tensordot(x, y, axes=2):
"""
Compute the tensordot product.
Compute a batched tensordot product.
A hybrid of batch_dot and tensordot, this function computes the
A hybrid of batched_dot and tensordot, this function computes the
tensordot product between the two tensors, by iterating over the
first dimension using scan to perform a sequence of tensordots.
first dimension to perform a sequence of tensordots.
Parameters
----------
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论