提交 fdb0a71b authored 作者: Olivier Breuleux's avatar Olivier Breuleux

error checking for the dtype

上级 8ee3b6d6
...@@ -41,6 +41,7 @@ class BaseTensor(ResultBase): ...@@ -41,6 +41,7 @@ class BaseTensor(ResultBase):
# constructor that works with an ndarray. # constructor that works with an ndarray.
ResultBase.__init__(self, role=role, name=name) ResultBase.__init__(self, role=role, name=name)
self._dtype = str(dtype) self._dtype = str(dtype)
self.dtype_specs() # this is just for error checking
self._broadcastable = tuple(broadcastable) self._broadcastable = tuple(broadcastable)
###################### ######################
...@@ -78,7 +79,10 @@ class BaseTensor(ResultBase): ...@@ -78,7 +79,10 @@ class BaseTensor(ResultBase):
""" """
#TODO: add more type correspondances for e.g. int32, int64, float32, #TODO: add more type correspondances for e.g. int32, int64, float32,
#complex64, etc. #complex64, etc.
return {'float64': (float, 'double', 'NPY_DOUBLE')}[self.dtype] try:
return {'float64': (float, 'double', 'NPY_DOUBLE')}[self.dtype]
except KeyError:
raise TypeError("Unsupported dtype for BaseTensor: %s" % self.dtype)
# #
# Hash for constant folding # Hash for constant folding
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论