提交 365e825c authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

updates to MaxAndArgmax and IncSubtensor

上级 b1d5f152
...@@ -6208,10 +6208,6 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6208,10 +6208,6 @@ class TestInferShape(utt.InferShapeTester):
# MaxAndArgmax, # MaxAndArgmax,
# Note: axis as a tensor.iscalar or constant conflicts with # Note: axis as a tensor.iscalar or constant conflicts with
# make_node in basic # make_node in basic
adtens3 = dtensor3()
aiscal = iscalar()
aconst = 1
aiscal_val = randint(0, 2, size=())
adtens3_val = rand(4, 5, 3) adtens3_val = rand(4, 5, 3)
self._compile_and_check([adtens3], self._compile_and_check([adtens3],
MaxAndArgmax()(adtens3, None), MaxAndArgmax()(adtens3, None),
...@@ -6279,12 +6275,13 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6279,12 +6275,13 @@ class TestInferShape(utt.InferShapeTester):
""" """
# IncSubtensor # IncSubtensor
# TODO: populate with tensors of varying dimensions # Note: Is testing only for the 4-tensor below sufficient?
# Please determine and take action.
admat = dmatrix() admat = dmatrix()
bdmat = dmatrix() bdmat = dmatrix()
advec = dvector() advec = dvector()
adscal = dscalar() adscal = dscalar()
admat_val = rand(4, 4) admat_val = rand(5, 4)
self._compile_and_check([admat, bdmat], self._compile_and_check([admat, bdmat],
[inc_subtensor(admat[2:4], bdmat)], [inc_subtensor(admat[2:4], bdmat)],
[admat_val, [[1, 2, 3, 4]]], IncSubtensor) [admat_val, [[1, 2, 3, 4]]], IncSubtensor)
...@@ -6297,8 +6294,12 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6297,8 +6294,12 @@ class TestInferShape(utt.InferShapeTester):
[inc_subtensor(admat[2, 3], adscal)], [inc_subtensor(admat[2, 3], adscal)],
[admat_val, 1], IncSubtensor) [admat_val, 1], IncSubtensor)
self._compile_and_check([admat, adscal],
[inc_subtensor(admat[1:3, 2], adscal)],
[admat_val, 1], IncSubtensor)
self._compile_and_check([admat, bdmat], self._compile_and_check([admat, bdmat],
[set_subtensor(admat[2:3], bdmat)], [set_subtensor(admat[2:4], bdmat)],
[admat_val, [[1, 2, 3, 4]]], IncSubtensor) [admat_val, [[1, 2, 3, 4]]], IncSubtensor)
self._compile_and_check([admat, advec], self._compile_and_check([admat, advec],
...@@ -6306,9 +6307,47 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6306,9 +6307,47 @@ class TestInferShape(utt.InferShapeTester):
[admat_val, [1, 2, 3, 4]], IncSubtensor) [admat_val, [1, 2, 3, 4]], IncSubtensor)
self._compile_and_check([admat, adscal], self._compile_and_check([admat, adscal],
[set_subtensor(admat[2, ], adscal)], [set_subtensor(admat[2, 3], adscal)],
[admat_val, 1], IncSubtensor)
self._compile_and_check([admat, adscal],
[set_subtensor(admat[1:3, 2], adscal)],
[admat_val, 1], IncSubtensor) [admat_val, 1], IncSubtensor)
bdtens4 = dtensor4()
adtens4_val = rand(3, 4, 2, 5)
self._compile_and_check([adtens4, bdtens4],
[inc_subtensor(adtens4[::, 2:4, ::, ::], bdtens4)],
[adtens4_val, [[[[1, 2, 3, 4, 5]]]]], IncSubtensor)
self._compile_and_check([adtens4, bdmat],
[inc_subtensor(adtens4[2, 2:4, 1, ::], bdmat)],
[adtens4_val, [[1, 2, 3, 4, 5]]], IncSubtensor)
self._compile_and_check([adtens4, advec],
[inc_subtensor(adtens4[0, 1, ::, 4], advec)],
[adtens4_val, [1, 2]], IncSubtensor)
self._compile_and_check([adtens4, adscal],
[inc_subtensor(adtens4[1:3, 1, ::, 2:4], adscal)],
[adtens4_val, 1], IncSubtensor)
self._compile_and_check([adtens4, bdtens4],
[set_subtensor(adtens4[::, 2:4, ::, ::], bdtens4)],
[adtens4_val, [[[[1, 2, 3, 4, 5]]]]], IncSubtensor)
self._compile_and_check([adtens4, bdmat],
[set_subtensor(adtens4[2, 2:4, 1, ::], bdmat)],
[adtens4_val, [[1, 2, 3, 4, 5]]], IncSubtensor)
self._compile_and_check([adtens4, advec],
[set_subtensor(adtens4[0, 1, ::, 4], advec)],
[adtens4_val, [1, 2]], IncSubtensor)
self._compile_and_check([adtens4, adscal],
[set_subtensor(adtens4[1:3, 1, ::, 2:4], adscal)],
[adtens4_val, 1], IncSubtensor)
# AdvancedIncSubtensor1 # AdvancedIncSubtensor1
# TODO: populate with tensors and lists of varying dimensions and lengths # TODO: populate with tensors and lists of varying dimensions and lengths
admat = dmatrix() admat = dmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论