提交 f654e792 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

update unit tests to use new tensordot function, remove infer_shape tests for…

update unit tests to use new tensordot function, remove infer_shape tests for tensordot and tensordotgrad
上级 5463e79d
...@@ -32,11 +32,11 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -32,11 +32,11 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
tensor4, permute_row_elements, Flatten, fmatrix, fscalars, grad, tensor4, permute_row_elements, Flatten, fmatrix, fscalars, grad,
inplace, iscalar, matrix, minimum, matrices, maximum, mul, neq, inplace, iscalar, matrix, minimum, matrices, maximum, mul, neq,
Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor, Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor,
tensor_copy, tensordot, tensordot_grad, TensorType, unbroadcast, tensor_copy, tensordot, TensorType, unbroadcast,
var, Join, shape, MaxAndArgmax, lscalar, zvector, exp, var, Join, shape, MaxAndArgmax, lscalar, zvector, exp,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements, tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc, ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1, dtensor3, SpecifyShape, Mean, IncSubtensor, AdvancedIncSubtensor1,
...@@ -5406,6 +5406,13 @@ class TestPermuteRowElements(unittest.TestCase): ...@@ -5406,6 +5406,13 @@ class TestPermuteRowElements(unittest.TestCase):
class test_tensordot(unittest.TestCase): class test_tensordot(unittest.TestCase):
def TensorDot(self, axes):
"""
Since tensordot is no longer an op, mimic the old op signature
to allow easy use of verify_grad.
"""
return lambda a, b : tensordot(a, b, axes)
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
...@@ -5421,7 +5428,7 @@ class test_tensordot(unittest.TestCase): ...@@ -5421,7 +5428,7 @@ class test_tensordot(unittest.TestCase):
bval = rand(5) bval = rand(5)
self.assertTrue(numpy.tensordot(aval, bval, axes) == \ self.assertTrue(numpy.tensordot(aval, bval, axes) == \
f1(aval, bval)) f1(aval, bval))
utt.verify_grad(TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test matrix-vector # Test matrix-vector
bmat = matrix() bmat = matrix()
...@@ -5432,7 +5439,7 @@ class test_tensordot(unittest.TestCase): ...@@ -5432,7 +5439,7 @@ class test_tensordot(unittest.TestCase):
bval = rand(8, 5) bval = rand(8, 5)
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes), self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f2(aval, bval))) f2(aval, bval)))
utt.verify_grad(TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test matrix-matrix # Test matrix-matrix
amat = matrix() amat = matrix()
...@@ -5451,7 +5458,7 @@ class test_tensordot(unittest.TestCase): ...@@ -5451,7 +5458,7 @@ class test_tensordot(unittest.TestCase):
bval = rand(*shps[1]) bval = rand(*shps[1])
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes), self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f3(aval, bval))) f3(aval, bval)))
utt.verify_grad(TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test ndarray-matrix, sum over one dim of matrix # Test ndarray-matrix, sum over one dim of matrix
for axes, shps in [[((2,), (1,)), [(1, 2, 3, 4), (2, 3)]], for axes, shps in [[((2,), (1,)), [(1, 2, 3, 4), (2, 3)]],
...@@ -5469,7 +5476,7 @@ class test_tensordot(unittest.TestCase): ...@@ -5469,7 +5476,7 @@ class test_tensordot(unittest.TestCase):
bval = rand(*shps[1]) bval = rand(*shps[1])
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes), self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f4(aval, bval))) f4(aval, bval)))
utt.verify_grad(TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test ndarray-ndarray # Test ndarray-ndarray
atens = tensor4() atens = tensor4()
...@@ -5481,14 +5488,14 @@ class test_tensordot(unittest.TestCase): ...@@ -5481,14 +5488,14 @@ class test_tensordot(unittest.TestCase):
bval = rand(3, 4, 2) bval = rand(3, 4, 2)
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes), self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f5(aval, bval))) f5(aval, bval)))
utt.verify_grad(TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
axes = (axes[1], axes[0]) axes = (axes[1], axes[0])
c = tensordot(btens, atens, axes) c = tensordot(btens, atens, axes)
f6 = inplace_func([btens, atens], c) f6 = inplace_func([btens, atens], c)
self.assertTrue(numpy.allclose(numpy.tensordot(bval, aval, axes), self.assertTrue(numpy.allclose(numpy.tensordot(bval, aval, axes),
f6(bval, aval))) f6(bval, aval)))
utt.verify_grad(TensorDot(axes), [bval, aval]) utt.verify_grad(self.TensorDot(axes), [bval, aval])
def test_raise_error(self): def test_raise_error(self):
amat = matrix() amat = matrix()
...@@ -5541,7 +5548,7 @@ class test_tensordot(unittest.TestCase): ...@@ -5541,7 +5548,7 @@ class test_tensordot(unittest.TestCase):
bval = rand(7, 9) bval = rand(7, 9)
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes), self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f3(aval, bval))) f3(aval, bval)))
utt.verify_grad(TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
def test_scalar_axes(self): def test_scalar_axes(self):
# Test matrix-matrix # Test matrix-matrix
...@@ -5555,7 +5562,7 @@ class test_tensordot(unittest.TestCase): ...@@ -5555,7 +5562,7 @@ class test_tensordot(unittest.TestCase):
f3 = inplace_func([amat, bmat], c) f3 = inplace_func([amat, bmat], c)
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes), self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f3(aval, bval))) f3(aval, bval)))
utt.verify_grad(TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
# Test tensor-tensor # Test tensor-tensor
amat = tensor3() amat = tensor3()
...@@ -5567,7 +5574,7 @@ class test_tensordot(unittest.TestCase): ...@@ -5567,7 +5574,7 @@ class test_tensordot(unittest.TestCase):
f3 = inplace_func([amat, bmat], c) f3 = inplace_func([amat, bmat], c)
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes), self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f3(aval, bval))) f3(aval, bval)))
utt.verify_grad(TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
def test_scalar0(self): def test_scalar0(self):
# Test tensor-tensor # Test tensor-tensor
...@@ -5580,7 +5587,7 @@ class test_tensordot(unittest.TestCase): ...@@ -5580,7 +5587,7 @@ class test_tensordot(unittest.TestCase):
f3 = inplace_func([amat, bmat], c) f3 = inplace_func([amat, bmat], c)
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes), self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f3(aval, bval))) f3(aval, bval)))
utt.verify_grad(TensorDot(axes), [aval, bval]) utt.verify_grad(self.TensorDot(axes), [aval, bval])
def test_tensordot_grad(self): def test_tensordot_grad(self):
# We test it manually as we recreate the op in the make_node # We test it manually as we recreate the op in the make_node
...@@ -6429,180 +6436,6 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6429,180 +6436,6 @@ class TestInferShape(utt.InferShapeTester):
def test_infer_shape(self): def test_infer_shape(self):
# tensordot_grad
admat = dmatrix()
bdmat = dmatrix()
gzdmat = dmatrix()
admat_val = rand(4, 5)
bdmat_val = rand(5, 3)
gzdmat_val = rand(4, 3)
axes = 1
self._compile_and_check([admat, bdmat, gzdmat],
tensordot_grad(axes)(admat, bdmat, gzdmat),
[admat_val, bdmat_val, gzdmat_val], tensordot_grad)
admat_val = rand(5, 4)
bdmat_val = rand(5, 4)
gzdscal = dscalar()
gzdscal_val = rand()
axes = 2
self._compile_and_check([admat, bdmat, gzdscal],
tensordot_grad(axes)(admat, bdmat, gzdscal),
[admat_val, bdmat_val, gzdscal_val], tensordot_grad)
admat_val = rand(4, 5)
bdmat_val = rand(5, 3)
gzdmat_val = rand(4, 3)
axes = ((1, ), (0, ))
self._compile_and_check([admat, bdmat, gzdmat],
tensordot_grad(axes)(admat, bdmat, gzdmat),
[admat_val, bdmat_val, gzdmat_val], tensordot_grad)
axes = ((1, 0))
self._compile_and_check([admat, bdmat, gzdmat],
tensordot_grad(axes)(admat, bdmat, gzdmat),
[admat_val, bdmat_val, gzdmat_val], tensordot_grad)
admat_val = rand(4, 5)
bdmat_val = rand(3, 4)
gzdmat_val = rand(5, 3)
axes = ((0, ), (1, ))
self._compile_and_check([admat, bdmat, gzdmat],
tensordot_grad(axes)(admat, bdmat, gzdmat),
[admat_val, bdmat_val, gzdmat_val], tensordot_grad)
gzdscal = dscalar()
admat_val = rand(5, 4)
bdmat_val = rand(5, 4)
gzdscal_val = rand()
axes = ((0, 1), (0, 1))
self._compile_and_check([admat, bdmat, gzdscal],
tensordot_grad(axes)(admat, bdmat, gzdscal),
[admat_val, bdmat_val, gzdscal_val], tensordot_grad)
# tensordot_grad currently do not support not ordered axes
"""
gzdscal = dscalar()
admat_val = rand(5, 4)
bdmat_val = rand(4, 5)
gzdscal_val = rand()
axes = ((0, 1), (1, 0))
self._compile_and_check([admat, bdmat, gzdscal],
tensordot_grad(axes)(admat, bdmat, gzdscal),
[admat_val, bdmat_val, gzdscal_val], tensordot_grad)
gzdscal = dscalar()
admat_val = rand(5, 4)
bdmat_val = rand(5, 4)
gzdscal_val = rand()
axes = ((1, 0 ), (1, 0))
self._compile_and_check([admat, bdmat, gzdscal],
tensordot_grad(axes)(admat, bdmat, gzdscal),
[admat_val, bdmat_val, gzdscal_val], tensordot_grad)
"""
# tensordot
admat = dmatrix()
bdmat = dmatrix()
admat_val = rand(4, 5)
bdmat_val = rand(5, 3)
axes = 1
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(5, 4)
bdmat_val = rand(5, 4)
axes = 2
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(4, 5)
bdmat_val = rand(5, 3)
axes = ((1, ), (0, ))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
axes = ((1, 0))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(4, 5)
bdmat_val = rand(3, 4)
axes = ((0, ), (1, ))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
axes = ((0, 1))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(5, 4)
bdmat_val = rand(4, 5)
axes = ((1,), (0,))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
axes = ((0, 1), (1, 0))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(5, 4)
bdmat_val = rand(5, 4)
axes = ((0, 1), (0, 1))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(5, 4)
bdmat_val = rand(4, 5)
axes = ((1, 0), (0, 1))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
adtens3 = dtensor3()
admat_val = rand(5, 4)
adtens3_val = rand(5, 4, 3)
axes = 2
self._compile_and_check([admat, adtens3],
[TensorDot(axes)(admat, adtens3)],
[admat_val, adtens3_val], TensorDot)
adtens3_val = rand(4, 5, 3)
axes = ((1, 0), (0, 1))
self._compile_and_check([admat, adtens3],
[TensorDot(axes)(admat, adtens3)],
[admat_val, adtens3_val], TensorDot)
adtens3_val = rand(4, 3, 5)
axes = ((1, 0), (0, 2))
self._compile_and_check([admat, adtens3],
[TensorDot(axes)(admat, adtens3)],
[admat_val, adtens3_val], TensorDot)
adtens4 = dtensor4()
admat_val = rand(5, 4)
adtens4_val = rand(5, 4, 3, 2)
axes = 2
self._compile_and_check([admat, adtens4],
[TensorDot(axes)(admat, adtens4)],
[admat_val, adtens4_val], TensorDot)
adtens4_val = rand(4, 3, 2, 5)
axes = ((1, 0), (0, 3))
self._compile_and_check([admat, adtens4],
[TensorDot(axes)(admat, adtens4)],
[admat_val, adtens4_val], TensorDot)
# Flatten # Flatten
atens3 = tensor3() atens3 = tensor3()
atens3_val = rand(4, 5, 3) atens3_val = rand(4, 5, 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论