提交 51a3f3a5 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

changes made for clarity, per code review

上级 cbee7123
...@@ -243,7 +243,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -243,7 +243,7 @@ class TestComputeTestValue(unittest.TestCase):
except ValueError, e: except ValueError, e:
# Get traceback # Get traceback
tb = sys.exc_info()[2] tb = sys.exc_info()[2]
# Get frame info 3 layers up # Get frame info 4 layers up
frame_info = traceback.extract_tb(tb)[-5] frame_info = traceback.extract_tb(tb)[-5]
# We should be in the "fx" function defined above # We should be in the "fx" function defined above
assert os.path.split(frame_info[0])[1] == 'test_compute_test_value.py' assert os.path.split(frame_info[0])[1] == 'test_compute_test_value.py'
......
...@@ -6872,12 +6872,14 @@ class Dot(Op): ...@@ -6872,12 +6872,14 @@ class Dot(Op):
equivalent to matrix multiplication. For two vectors, this is the inner equivalent to matrix multiplication. For two vectors, this is the inner
product. product.
:note: matrix-matrix products are sometimes optimized to Dot22 ops :note: matrix-matrix products are sometimes optimized to Dot22 or Gemm ops.
(see tensor.blas) (see tensor.blas)
:note: non matrix-matrix products (including matrix-vector :note: vector-vector products are sometimes optimized to Ger or CGer. (see
products) are handled by numpy. Ensure that you have linked numpy tensor.blas)
with a fast BLAS.
:note: matrix-vector products are sometimes optimized to Gemv, CGemv (see
tensor.blas)
""" """
...@@ -7082,11 +7084,14 @@ def dot(a, b): ...@@ -7082,11 +7084,14 @@ def dot(a, b):
3. If both a and b have either 1 or 2 dimensions, it calls Theano's 3. If both a and b have either 1 or 2 dimensions, it calls Theano's
Dot op on a and b. Dot op on a and b.
:note: matrix-matrix products are sometimes optimized to Dot22 ops. :note: matrix-matrix products are sometimes optimized to Dot22 or Gemm ops.
(see tensor.blas)
:note: vector-vector products are sometimes optimized to Ger or CGer. (see
tensor.blas)
:note: non matrix-matrix products (including matrix-vector :note: matrix-vector products are sometimes optimized to Gemv, CGemv (see
products) are handled by numpy. Ensure that you have linked numpy tensor.blas)
with a fast BLAS.
""" """
a, b = as_tensor_variable(a), as_tensor_variable(b) a, b = as_tensor_variable(a), as_tensor_variable(b)
...@@ -7103,7 +7108,6 @@ def dot(a, b): ...@@ -7103,7 +7108,6 @@ def dot(a, b):
######################### #########################
# Linalg : TensorDot # Linalg : TensorDot
######################### #########################
# TODO: tensordot should be function as described in rst docs.
def tensordot(a, b, axes = 2): def tensordot(a, b, axes = 2):
""" """
...@@ -7195,15 +7199,18 @@ def tensordot(a, b, axes = 2): ...@@ -7195,15 +7199,18 @@ def tensordot(a, b, axes = 2):
# axes must be a scalar or list/tuple of length 2 # axes must be a scalar or list/tuple of length 2
if not numpy.isscalar(axes) and len(axes) != 2: if not numpy.isscalar(axes) and len(axes) != 2:
raise ValueError('Axes should be scalar valued or a ' raise ValueError('Axes should be scalar valued or a '
'list/tuple of len 2.') 'list/tuple of len 2 (%r was provided)' % axes)
# if 'axes' is a number of axes to multiply and sum over (trailing axes # if 'axes' is a number of axes to multiply and sum over (trailing axes
# of a, leading axes of b), we can just reshape and use dot. # of a, leading axes of b), we can just reshape and use dot.
elif numpy.isscalar(axes): elif numpy.isscalar(axes):
# check if axes is valid given the dimension of a and b # check if axes is valid given the dimension of a and b
if axes > a.ndim or axes > b.ndim: if axes > a.ndim:
raise ValueError('axes should be smaller than the dimension of '
'a (a.ndim=%i, axes=%i)' % (a.ndim, axes))
if axes > b.ndim:
raise ValueError('axes should be smaller than the dimension of ' raise ValueError('axes should be smaller than the dimension of '
'a and b (a.ndim=%i, b.ndim=%i)' % (a.ndim, b.ndim)) 'b (b.ndim=%i, axes=%i)' % (b.ndim, axes))
outshape = concatenate([a.shape[:a.ndim - axes], b.shape[axes:]]) outshape = concatenate([a.shape[:a.ndim - axes], b.shape[axes:]])
outndim = a.ndim + b.ndim - (2 * axes) outndim = a.ndim + b.ndim - (2 * axes)
......
...@@ -4310,19 +4310,11 @@ class t_dot(unittest.TestCase): ...@@ -4310,19 +4310,11 @@ class t_dot(unittest.TestCase):
#numpy return matrix not aligned... #numpy return matrix not aligned...
def test_dot_1d_1d0(self): def test_dot_1d_1d0(self):
try: self.assertRaises(ValueError, self.cmp_dot, rand(5), rand(0))
self.cmp_dot(rand(5), rand(0))
assert False
except ValueError:
pass
#numpy return matrix not aligned... #numpy return matrix not aligned...
def test_dot_1d0_1d(self): def test_dot_1d0_1d(self):
try: self.assertRaises(ValueError, self.cmp_dot, rand(0), rand(5))
self.cmp_dot(rand(0), rand(5))
assert False
except ValueError:
pass
def test_dot_1d_2d(self): def test_dot_1d_2d(self):
self.cmp_dot(rand(6), rand(6, 7)) self.cmp_dot(rand(6), rand(6, 7))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论