提交 a5a68aaf authored 作者: Frederic Bastien's avatar Frederic Bastien

update tests doc why we don't test multinomal infer_shape.

上级 d001aa7e
...@@ -922,10 +922,12 @@ class T_random_function(utt.InferShapeTester): ...@@ -922,10 +922,12 @@ class T_random_function(utt.InferShapeTester):
self._compile_and_check([rng_R], [out], [rng_R_val], self._compile_and_check([rng_R], [out], [rng_R_val],
RandomFunction) RandomFunction)
"""
#infer_shape don't work for multinomial.
#The parameter ndim_added is set to 1 and in this case, the infer_shape
#inplementation don't know how to infer the shape
post_r, out = multinomial(rng_R) post_r, out = multinomial(rng_R)
# ERROR: 'graph contains cycles'
"""
self._compile_and_check([rng_R], [out], [rng_R_val], self._compile_and_check([rng_R], [out], [rng_R_val],
RandomFunction) RandomFunction)
""" """
...@@ -943,6 +945,8 @@ class T_random_function(utt.InferShapeTester): ...@@ -943,6 +945,8 @@ class T_random_function(utt.InferShapeTester):
RandomFunction) RandomFunction)
# multinomial, specified shape # multinomial, specified shape
"""
#infer_shape don't work for multinomial
n = iscalar() n = iscalar()
pvals = dvector() pvals = dvector()
size_val = (7, 3) size_val = (7, 3)
...@@ -951,9 +955,6 @@ class T_random_function(utt.InferShapeTester): ...@@ -951,9 +955,6 @@ class T_random_function(utt.InferShapeTester):
post_r, out = multinomial(rng_R, size=size_val, n=n, pvals=pvals, post_r, out = multinomial(rng_R, size=size_val, n=n, pvals=pvals,
ndim=2) ndim=2)
# ERROR: 'graph contains cycles'
# see NOTE 1 below
"""
self._compile_and_check([rng_R, n, pvals], [out], self._compile_and_check([rng_R, n, pvals], [out],
[rng_R_val, n_val, pvals_val], [rng_R_val, n_val, pvals_val],
RandomFunction) RandomFunction)
...@@ -1054,14 +1055,14 @@ class T_random_function(utt.InferShapeTester): ...@@ -1054,14 +1055,14 @@ class T_random_function(utt.InferShapeTester):
std_val], RandomFunction) std_val], RandomFunction)
# multinomial with tensor-3 probabilities # multinomial with tensor-3 probabilities
"""
#multinomial infer_shape don't work.
pvals = dtensor3() pvals = dtensor3()
n = iscalar() n = iscalar()
post_r, out = multinomial(rng_R, n=n, pvals=pvals, size=(1, -1)) post_r, out = multinomial(rng_R, n=n, pvals=pvals, size=(1, -1))
pvals_val = [[[.1, .9], [.2, .8], [.3, .7]]] pvals_val = [[[.1, .9], [.2, .8], [.3, .7]]]
n_val = 9 n_val = 9
# ERROR: 'graph contains cycles'
"""
self._compile_and_check([rng_R, n, pvals], [out], self._compile_and_check([rng_R, n, pvals], [out],
[rng_R_val, n_val, [rng_R_val, n_val,
pvals_val], RandomFunction) pvals_val], RandomFunction)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论