提交 0488b3cf authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed a test when cast_policy == numpy+floatX and floatX == float32

上级 a00d98ff
...@@ -10,7 +10,7 @@ except ImportError: ...@@ -10,7 +10,7 @@ except ImportError:
pass#the variable enable_sparse will be used to disable the test file. pass#the variable enable_sparse will be used to disable the test file.
import theano import theano
from theano import compile from theano import compile, config
from theano.sparse import enable_sparse from theano.sparse import enable_sparse
if enable_sparse == False: if enable_sparse == False:
raise SkipTest('Optional package sparse disabled') raise SkipTest('Optional package sparse disabled')
...@@ -239,8 +239,18 @@ class T_AddMul(unittest.TestCase): ...@@ -239,8 +239,18 @@ class T_AddMul(unittest.TestCase):
self.assertRaises(NotImplementedError, add, a_sv, c_dv) self.assertRaises(NotImplementedError, add, a_sv, c_dv)
self.assertRaises(NotImplementedError, add, c_sv, a_dv) self.assertRaises(NotImplementedError, add, c_sv, a_dv)
# mul upcasts the dense input if needed # mul may upcast the dense input if needed
self.assertRaises(NotImplementedError, mul, a_sv, b_dv) if (config.cast_policy in ('custom', 'numpy') or
(config.cast_policy == 'numpy+floatX' and
config.floatX == 'float64')):
# The result should be a float64 (not implemented).
self.assertRaises(NotImplementedError, mul, a_sv, b_dv)
elif (config.cast_policy == 'numpy+floatX' and
config.floatX == 'float32'):
# The result should be a float32.
assert mul(a_sv, b_dv).dtype == 'float32'
else:
raise NotImplementedError()
self.assertRaises(NotImplementedError, mul, b_sv, a_dv) self.assertRaises(NotImplementedError, mul, b_sv, a_dv)
assert mul(b_sv, c_dv).dtype == 'int32' assert mul(b_sv, c_dv).dtype == 'int32'
self.assertRaises(NotImplementedError, mul, c_sv, b_dv) self.assertRaises(NotImplementedError, mul, c_sv, b_dv)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论