提交 9ffec773 authored 作者: --global's avatar --global

Add tests for broadcastable pattern of flatten output

上级 4d659548
...@@ -5100,6 +5100,31 @@ def test_flatten_outdim2_of_3(): ...@@ -5100,6 +5100,31 @@ def test_flatten_outdim2_of_3():
utt.verify_grad(Flatten(2), [a_val]) utt.verify_grad(Flatten(2), [a_val])
def test_flatten_broadcastable():
# Ensure that the broadcastable pattern of the output is coherent with
# that of the input
inp = TensorType('float64', (False, False, False, False))()
out = flatten(inp, outdim=2)
assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, False, False, True))()
out = flatten(inp, outdim=2)
assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, True, False, True))()
out = flatten(inp, outdim=2)
assert out.broadcastable == (False, False)
inp = TensorType('float64', (False, True, True, True))()
out = flatten(inp, outdim=2)
assert out.broadcastable == (False, True)
inp = TensorType('float64', (True, False, True, True))()
out = flatten(inp, outdim=3)
assert out.broadcastable == (True, False, True)
def test_flatten_outdim_invalid(): def test_flatten_outdim_invalid():
a = dmatrix() a = dmatrix()
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论