提交 e35e76ab authored 作者: James Bergstra's avatar James Bergstra

so sad.. backporting with statements in test_autocast

上级 cdd6e33e
...@@ -2224,22 +2224,40 @@ def test_autocast(): ...@@ -2224,22 +2224,40 @@ def test_autocast():
orig_autocast = autocast_float.dtypes orig_autocast = autocast_float.dtypes
# test that autocast_float_as sets the autocast dtype correctly # test that autocast_float_as sets the autocast dtype correctly
with autocast_float_as('float32') as ac: try: #ghetto 2.4 version of with
ac = autocast_float_as('float32')
ac.__enter__()
assert autocast_float.dtypes == ('float32',) assert autocast_float.dtypes == ('float32',)
finally:
ac.__exit__()
assert autocast_float.dtypes == orig_autocast assert autocast_float.dtypes == orig_autocast
with autocast_float_as('float64') as ac: try: #ghetto 2.4 version of with
ac = autocast_float_as('float64')
ac.__enter__()
assert autocast_float.dtypes == ('float64',) assert autocast_float.dtypes == ('float64',)
finally:
ac.__exit__()
assert autocast_float.dtypes == orig_autocast assert autocast_float.dtypes == orig_autocast
# test that we can set it back to something, and nest it # test that we can set it back to something, and nest it
with autocast_float_as('float32') as ac: try: #ghetto 2.4 version of with
ac = autocast_float_as('float32')
ac.__enter__()
assert autocast_float.dtypes == ('float32',) assert autocast_float.dtypes == ('float32',)
with autocast_float_as('float64') as ac: try: #ghetto 2.4 version of with
ac2 = autocast_float_as('float64')
ac2.__enter__()
assert autocast_float.dtypes == ('float64',) assert autocast_float.dtypes == ('float64',)
finally:
ac2.__exit__()
assert autocast_float.dtypes == ('float32',) assert autocast_float.dtypes == ('float32',)
finally:
ac.__exit__()
assert autocast_float.dtypes == orig_autocast assert autocast_float.dtypes == orig_autocast
# test that the autocasting dtype is used correctly in expression-building # test that the autocasting dtype is used correctly in expression-building
with autocast_float_as('float32') as ac: try: #ghetto 2.4 version of with
ac = autocast_float_as('float32')
ac.__enter__()
assert (dvector()+ 1.1).dtype == 'float64' assert (dvector()+ 1.1).dtype == 'float64'
assert (fvector()+ 1.1).dtype == 'float32' assert (fvector()+ 1.1).dtype == 'float32'
assert (fvector()+ numpy.asarray(1.1,dtype='float64')).dtype == 'float64' assert (fvector()+ numpy.asarray(1.1,dtype='float64')).dtype == 'float64'
...@@ -2247,9 +2265,13 @@ def test_autocast(): ...@@ -2247,9 +2265,13 @@ def test_autocast():
assert (dvector()+ 1).dtype == 'float64' assert (dvector()+ 1).dtype == 'float64'
assert (fvector()+ 1).dtype == 'float32' assert (fvector()+ 1).dtype == 'float32'
finally:
ac.__exit__()
# test that the autocasting dtype is used correctly in expression-building # test that the autocasting dtype is used correctly in expression-building
with autocast_float_as('float64') as ac: try: #ghetto 2.4 version of with
ac = autocast_float_as('float64')
ac.__enter__()
assert (dvector()+ 1.1).dtype == 'float64' assert (dvector()+ 1.1).dtype == 'float64'
assert (fvector()+ 1.1).dtype == 'float64' assert (fvector()+ 1.1).dtype == 'float64'
assert (fvector()+ 1.0).dtype == 'float64' assert (fvector()+ 1.0).dtype == 'float64'
...@@ -2258,14 +2280,24 @@ def test_autocast(): ...@@ -2258,14 +2280,24 @@ def test_autocast():
assert (dvector()+ 1).dtype == 'float64' assert (dvector()+ 1).dtype == 'float64'
assert (fvector()+ 1).dtype == 'float32' assert (fvector()+ 1).dtype == 'float32'
finally:
ac.__exit__()
# test that the autocasting dtype is used correctly in expression-building # test that the autocasting dtype is used correctly in expression-building
with autocast_float_as('float32', 'float64') as ac: try: #ghetto 2.4 version of with
ac = autocast_float_as('float32', 'float64')
ac.__enter__()
assert (dvector()+ 1.1).dtype == 'float64' assert (dvector()+ 1.1).dtype == 'float64'
assert (fvector()+ 1.1).dtype == 'float64' assert (fvector()+ 1.1).dtype == 'float64'
assert (fvector()+ 1.0).dtype == 'float32' assert (fvector()+ 1.0).dtype == 'float32'
with autocast_float_as('float64') as ac: try: #ghetto 2.4 version of with
ac2 = autocast_float_as('float64')
ac2.__enter__()
assert (fvector()+ 1.0).dtype == 'float64' assert (fvector()+ 1.0).dtype == 'float64'
finally:
ac2.__exit__()
finally:
ac.__exit__()
if __name__ == '__main__': if __name__ == '__main__':
if len(sys.argv) >= 2 and sys.argv[1] == 'OPT': if len(sys.argv) >= 2 and sys.argv[1] == 'OPT':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论