提交 1e55dc82 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

improved error messages to filter in scalar and tensor

上级 2ccbf402
...@@ -45,8 +45,12 @@ class Scalar(Type): ...@@ -45,8 +45,12 @@ class Scalar(Type):
def filter(self, data, strict = False): def filter(self, data, strict = False):
py_type = self.dtype_specs()[0] py_type = self.dtype_specs()[0]
if strict: assert isinstance(data, py_type) if strict and not isinstance(data, py_type):
return py_type(data) raise TypeError("%s expected a %s" % (self, self.dtype), data)
try:
return py_type(data)
except Exception, e:
raise TypeError("Could not convert to %s" % self.dtype, e)
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and other.dtype == self.dtype return type(self) == type(other) and other.dtype == self.dtype
......
...@@ -12,7 +12,7 @@ import gof ...@@ -12,7 +12,7 @@ import gof
import blas # for gemm, dot import blas # for gemm, dot
import gradient import gradient
import elemwise as s2t import elemwise
import scalar as scal import scalar as scal
from gof.python25 import partial from gof.python25 import partial
...@@ -91,14 +91,19 @@ class Tensor(Type): ...@@ -91,14 +91,19 @@ class Tensor(Type):
def filter(self, data, strict = False): def filter(self, data, strict = False):
_data = data _data = data
if strict: if strict:
assert isinstance(data, numpy.ndarray) if not isinstance(data, numpy.ndarray):
assert str(data.dtype) == self.dtype raise TypeError("%s expected a ndarray object.", data, type(data))
if not str(data.dtype) == self.dtype:
raise TypeError("%s expected a ndarray object with dtype = %s (got %s)." % (self, self.dtype, data.dtype))
if not data.ndim == self.ndim:
raise TypeError("%s expected a ndarray object with %s dimensions (got %s)." % (self, self.ndim, data.ndim))
return data
else: else:
data = numpy.asarray(data, dtype = self.dtype) data = numpy.asarray(data, dtype = self.dtype)
if not self.ndim == data.ndim: if not self.ndim == data.ndim:
raise TypeError("Wrong number of dimensions: expected %s, got %s." % (self.ndim, data.ndim), _data) raise TypeError("Wrong number of dimensions: expected %s, got %s." % (self.ndim, data.ndim), _data)
if any(b and d != 1 for d, b in zip(data.shape, self.broadcastable)): if any(b and d != 1 for d, b in zip(data.shape, self.broadcastable)):
raise ValueError("Non-unit value on shape on a broadcastable dimension.", data.shape, self.broadcastable) raise TypeError("Non-unit value on shape on a broadcastable dimension.", data.shape, self.broadcastable)
return data return data
def dtype_specs(self): def dtype_specs(self):
...@@ -392,11 +397,11 @@ class TensorConstant(Constant, _tensor_py_operators): ...@@ -392,11 +397,11 @@ class TensorConstant(Constant, _tensor_py_operators):
class TensorValue(Value, _tensor_py_operators): class TensorValue(Value, _tensor_py_operators):
pass pass
s2t.as_tensor = as_tensor elemwise.as_tensor = as_tensor
s2t.Tensor = Tensor elemwise.Tensor = Tensor
s2t.TensorResult = TensorResult elemwise.TensorResult = TensorResult
s2t.TensorConstant = TensorConstant elemwise.TensorConstant = TensorConstant
s2t.TensorValue = TensorValue elemwise.TensorValue = TensorValue
...@@ -405,9 +410,9 @@ s2t.TensorValue = TensorValue ...@@ -405,9 +410,9 @@ s2t.TensorValue = TensorValue
######################### #########################
def _elemwise(scalar_op, name): def _elemwise(scalar_op, name):
straight = s2t.Elemwise(scalar_op) straight = elemwise.Elemwise(scalar_op, name = name)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0)) inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = s2t.Elemwise(inplace_scalar_op, {0: 0}, name = name) inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = name+"_inplace")
return straight, inplace return straight, inplace
...@@ -454,14 +459,14 @@ def cast(t, dtype): ...@@ -454,14 +459,14 @@ def cast(t, dtype):
'complex128': convert_to_complex128} 'complex128': convert_to_complex128}
return mapping[dtype](t) return mapping[dtype](t)
convert_to_int8 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int8))) convert_to_int8 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int8)))
convert_to_int16 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int16))) convert_to_int16 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int16)))
convert_to_int32 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int32))) convert_to_int32 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int32)))
convert_to_int64 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int64))) convert_to_int64 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int64)))
convert_to_float32 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.float32))) convert_to_float32 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.float32)))
convert_to_float64 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.float64))) convert_to_float64 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.float64)))
convert_to_complex64 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.complex64))) convert_to_complex64 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.complex64)))
convert_to_complex128 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.complex128))) convert_to_complex128 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.complex128)))
...@@ -580,10 +585,10 @@ def ones_like(model): ...@@ -580,10 +585,10 @@ def ones_like(model):
def zeros_like(model): def zeros_like(model):
return fill(model, 0.0) return fill(model, 0.0)
tensor_copy = s2t.Elemwise(scal.identity) tensor_copy = elemwise.Elemwise(scal.identity)
def sum(input, axis = None): def sum(input, axis = None):
return s2t.Sum(axis)(input) return elemwise.Sum(axis)(input)
########################## ##########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论