提交 fe6f980b authored 作者: Frederic Bastien's avatar Frederic Bastien

Upcast TensorConstant of ndarray of bool to uint8 to allow using them in Theano.

上级 5f16fe1f
...@@ -249,6 +249,10 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -249,6 +249,10 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
x_ = autocast_float(x) x_ = autocast_float(x)
elif isinstance(x, numpy.ndarray): elif isinstance(x, numpy.ndarray):
x_ = x x_ = x
# Currently we don't have a bool dtype in Theano
# So we upcast it to uint8 to don't break our interface for constant.
if x.dtype == 'bool':
x_ = numpy.asarray(x_, dtype='uint8')
else: else:
x_ = numpy.asarray(x) x_ = numpy.asarray(x)
......
...@@ -4018,6 +4018,22 @@ class T_get_constant_value(unittest.TestCase): ...@@ -4018,6 +4018,22 @@ class T_get_constant_value(unittest.TestCase):
for j in range(c.value.shape[1]): for j in range(c.value.shape[1]):
assert get_constant_value(c[i,j]) == c.value[i,j] assert get_constant_value(c[i,j]) == c.value[i,j]
class T_as_tensor_variable(unittest.TestCase):
"""
We test that ticket #649 stay fixed.
We should not allow as_tensor_variable to accept True or False
But it should upcast an ndrarray of bool to uint8
"""
def test_bool(self):
self.assertRaises(TypeError, as_tensor_variable, True)
self.assertRaises(TypeError, as_tensor_variable, False)
def test_ndarray_bool(self):
ten = as_tensor_variable(numpy.array([True, False, False, True, True]))
assert ten.type.dtype == 'uint8'
if __name__ == '__main__': if __name__ == '__main__':
if 1: if 1:
unittest.main() unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论