提交 5d515fa9 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

Join: added test on second axis; TensorDot: added axes ((0, 1), (1, 0));…

Join: added test on second axis; TensorDot: added axes ((0, 1), (1, 0)); MaxAndArgmax: removed TODO comment
上级 4e6ef4aa
...@@ -6135,6 +6135,11 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6135,6 +6135,11 @@ class TestInferShape(utt.InferShapeTester):
[TensorDot(axes)(admat, bdmat)], [TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot) [admat_val, bdmat_val], TensorDot)
axes = ((0, 1))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(5, 4) admat_val = rand(5, 4)
bdmat_val = rand(4, 5) bdmat_val = rand(4, 5)
axes = ((1,), (0,)) axes = ((1,), (0,))
...@@ -6142,6 +6147,11 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6142,6 +6147,11 @@ class TestInferShape(utt.InferShapeTester):
[TensorDot(axes)(admat, bdmat)], [TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot) [admat_val, bdmat_val], TensorDot)
axes = ((0, 1), (1, 0))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(5, 4) admat_val = rand(5, 4)
bdmat_val = rand(5, 4) bdmat_val = rand(5, 4)
axes = ((0, 1), (0, 1)) axes = ((0, 1), (0, 1))
...@@ -6273,6 +6283,14 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6273,6 +6283,14 @@ class TestInferShape(utt.InferShapeTester):
[Join()(aiscal, admat, bdmat, cdmat)], [Join()(aiscal, admat, bdmat, cdmat)],
[aiscal_val, admat_val, bdmat_val, cdmat_val], Join) [aiscal_val, admat_val, bdmat_val, cdmat_val], Join)
admat_val = rand(4, 1)
bdmat_val = rand(4, 3)
cdmat_val = rand(4, 2)
aiscal_val = 1
self._compile_and_check([aiscal, admat, bdmat, cdmat],
[Join()(aiscal, admat, bdmat, cdmat)],
[aiscal_val, admat_val, bdmat_val, cdmat_val], Join)
# PermuteRowElements # PermuteRowElements
abool = True abool = True
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
...@@ -6363,8 +6381,6 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6363,8 +6381,6 @@ class TestInferShape(utt.InferShapeTester):
ciscal_val, discal_val], Alloc) ciscal_val, discal_val], Alloc)
# MaxAndArgmax, # MaxAndArgmax,
# Note: axis as a tensor.iscalar or constant conflicts with
# make_node in basic
adtens3_val = rand(4, 5, 3) adtens3_val = rand(4, 5, 3)
self._compile_and_check([adtens3], self._compile_and_check([adtens3],
MaxAndArgmax()(adtens3, None), MaxAndArgmax()(adtens3, None),
...@@ -6602,7 +6618,6 @@ MakeVector [@A] '' 3 ...@@ -6602,7 +6618,6 @@ MakeVector [@A] '' 3
|AdvancedIncSubtensor{Finplace=ainplace=linplace=sinplace=e, T set_instead_of_incr set_instead_of_incu set_instead_of_ince} [@C] '' 0 |AdvancedIncSubtensor{Finplace=ainplace=linplace=sinplace=e, T set_instead_of_incr set_instead_of_incu set_instead_of_ince} [@C] '' 0
remaining op as a class: AdvancedIncSubtensor{Finplace=ainplace=linplace=sinplace=e, T set_instead_of_incr set_instead_of_incu set_instead_of_ince} remaining op as a class: AdvancedIncSubtensor{Finplace=ainplace=linplace=sinplace=e, T set_instead_of_incr set_instead_of_incu set_instead_of_ince}
(5, 4) [5 4] (5, 4) [5 4]
"""
aivec_val = [1, 3, 2] aivec_val = [1, 3, 2]
bivec_val = [0, 3, 3] bivec_val = [0, 3, 3]
...@@ -6610,6 +6625,7 @@ remaining op as a class: AdvancedIncSubtensor{Finplace=ainplace=linplace=sinplac ...@@ -6610,6 +6625,7 @@ remaining op as a class: AdvancedIncSubtensor{Finplace=ainplace=linplace=sinplac
self._compile_and_check([admat, advec], self._compile_and_check([admat, advec],
[set_subtensor(admat[aivec_val, bivec_val], advec)], [set_subtensor(admat[aivec_val, bivec_val], advec)],
[admat_val, advec_val], AdvancedIncSubtensor) [admat_val, advec_val], AdvancedIncSubtensor)
"""
# Reshape # Reshape
# TODO: The shape is apparently generated correctly but the final result is abnormal: # TODO: The shape is apparently generated correctly but the final result is abnormal:
...@@ -6628,7 +6644,7 @@ MakeVector [@A] '' 3 ...@@ -6628,7 +6644,7 @@ MakeVector [@A] '' 3
remaining op as a class: Reshape{2} remaining op as a class: Reshape{2}
shapes generated: shapes generated:
(4, 3) [4 3] (4, 3) [4 3]
"""
admat = dmatrix() admat = dmatrix()
ndim = 2 ndim = 2
...@@ -6661,7 +6677,7 @@ shapes generated: ...@@ -6661,7 +6677,7 @@ shapes generated:
self._compile_and_check([adtens4], self._compile_and_check([adtens4],
[tile(adtens4, aivec_val, ndim)], [tile(adtens4, aivec_val, ndim)],
[adtens4_val], Tile) [adtens4_val], Tile)
"""
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论