提交 91fe0a2b authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: write docstrings

上级 1a285716
...@@ -3374,22 +3374,9 @@ def transpose(x, axes=None): ...@@ -3374,22 +3374,9 @@ def transpose(x, axes=None):
class BatchedDot(Op): class BatchedDot(Op):
""" """
Dot but Batched Computes the batched dot product of two variables:
WRITEME
Computes the dot product of two variables. For two matrices, this is
equivalent to matrix multiplication. For two vectors, this is the inner
product.
Notes
-----
Matrix-matrix products are sometimes optimized to Dot22 or Gemm ops
(see tensor.blas).
Vector-vector products are sometimes optimized to Ger or CGer (see
tensor.blas).
Matrix-vector products are sometimes optimized to Gemv, CGemv (see
tensor.blas).
batched_dot(a, b)[i] = dot(a[i], b[i])
""" """
__props__ = () __props__ = ()
...@@ -3497,7 +3484,7 @@ class BatchedDot(Op): ...@@ -3497,7 +3484,7 @@ class BatchedDot(Op):
iv1 = gof.op.get_test_value(inputs[1]) iv1 = gof.op.get_test_value(inputs[1])
except AttributeError: except AttributeError:
gof.op.missing_test_message( gof.op.missing_test_message(
'second input passed to Dot.R_op has no test value') 'second input passed to BatchedDot.R_op has no test value')
debugger_available = False debugger_available = False
if eval_points[0]: if eval_points[0]:
...@@ -3505,7 +3492,7 @@ class BatchedDot(Op): ...@@ -3505,7 +3492,7 @@ class BatchedDot(Op):
ev0 = gof.op.get_test_value(eval_points[0]) ev0 = gof.op.get_test_value(eval_points[0])
except AttributeError: except AttributeError:
gof.op.missing_test_message( gof.op.missing_test_message(
'first eval point passed to Dot.R_op ' 'first eval point passed to BatchedDot.R_op '
'has no test value') 'has no test value')
debugger_available = False debugger_available = False
if eval_points[1]: if eval_points[1]:
...@@ -3513,7 +3500,7 @@ class BatchedDot(Op): ...@@ -3513,7 +3500,7 @@ class BatchedDot(Op):
ev1 = gof.op.get_test_value(eval_points[1]) ev1 = gof.op.get_test_value(eval_points[1])
except AttributeError: except AttributeError:
gof.op.missing_test_message( gof.op.missing_test_message(
'second eval point passed to Dot.R_op ' 'second eval point passed to BatchedDot.R_op '
'has no test value') 'has no test value')
debugger_available = False debugger_available = False
...@@ -3526,7 +3513,7 @@ class BatchedDot(Op): ...@@ -3526,7 +3513,7 @@ class BatchedDot(Op):
input_values[i].shape != eval_point_values[i].shape: input_values[i].shape != eval_point_values[i].shape:
raise ValueError( raise ValueError(
'input ' + str(i) + ' and eval_point ' + str(i) + 'input ' + str(i) + ' and eval_point ' + str(i) +
' to Dot.R_op should have the same shape, but ' ' to BatchedDot.R_op should have the same shape, but '
'their shapes are %s and %s, respectively' % ( 'their shapes are %s and %s, respectively' % (
str(input_values[i].shape), str(input_values[i].shape),
str(eval_point_values[i].shape))) str(eval_point_values[i].shape)))
...@@ -3555,41 +3542,24 @@ class BatchedDot(Op): ...@@ -3555,41 +3542,24 @@ class BatchedDot(Op):
def batched_dot(a, b): def batched_dot(a, b):
""" """
WRITEME Compute the batched dot product of two variables:
Computes the dot product of two variables.
For two matrices, this is equivalent to matrix multiplication.
For two vectors, this is the inner product.
When one variable is a scalar, this is like elementwise multiplication.
For N dimensions, this is a sum product over the last axis
of the first array and the second-to-last axis of the second array:
dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
Note that this dot function does one of three things, in the following batched_dot(a, b)[i] = dot(a[i], b[i])
sequence:
1. If either a or b is scalar, it returns the elementwise product Note that this batched_dot function does one of three things, in the
without calling the Theano Dot op. following sequence:
2. If either a or b has more than 2 dimensions, it calls Theano's 1. If either a or b is a vector, it returns the batched elementwise
tensordot function with appropriate axes. The tensordot function product without calling the Theano BatchedDot op.
expresses high-dimensional dot products in terms of 2D matrix
multiplications, so it may be possible to further optimize for
performance.
3. If both a and b have either 1 or 2 dimensions, it calls Theano's 2. If both a and b have either 2 or 3 dimensions, it calls Theano's
Dot op on a and b. BatchedDot op on a and b.
Notes
-----
Matrix-matrix products are sometimes optimized to Dot22 or Gemm ops
(see tensor.blas).
Vector-vector products are sometimes optimized to Ger or CGer (see
tensor.blas).
Matrix-vector products are sometimes optimized to Gemv, CGemv (see
tensor.blas).
3. If either a or b has more than 3 dimensions, it calls Theano's
batched_tensordot function with appropriate axes. The
batched_tensordot function expresses high-dimensional batched
dot products in terms of batched matrix-matrix dot products, so
it may be possible to further optimize for performance.
""" """
a, b = as_tensor_variable(a), as_tensor_variable(b) a, b = as_tensor_variable(a), as_tensor_variable(b)
...@@ -3606,11 +3576,11 @@ def batched_dot(a, b): ...@@ -3606,11 +3576,11 @@ def batched_dot(a, b):
def batched_tensordot(x, y, axes=2): def batched_tensordot(x, y, axes=2):
""" """
Compute the tensordot product. Compute a batched tensordot product.
A hybrid of batch_dot and tensordot, this function computes the A hybrid of batched_dot and tensordot, this function computes the
tensordot product between the two tensors, by iterating over the tensordot product between the two tensors, by iterating over the
first dimension using scan to perform a sequence of tensordots. first dimension to perform a sequence of tensordots.
Parameters Parameters
---------- ----------
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论