提交 1e55dc82 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

improved error messages to filter in scalar and tensor

上级 2ccbf402
......@@ -45,8 +45,12 @@ class Scalar(Type):
def filter(self, data, strict = False):
py_type = self.dtype_specs()[0]
if strict: assert isinstance(data, py_type)
return py_type(data)
if strict and not isinstance(data, py_type):
raise TypeError("%s expected a %s" % (self, self.dtype), data)
try:
return py_type(data)
except Exception, e:
raise TypeError("Could not convert to %s" % self.dtype, e)
def __eq__(self, other):
return type(self) == type(other) and other.dtype == self.dtype
......
......@@ -12,7 +12,7 @@ import gof
import blas # for gemm, dot
import gradient
import elemwise as s2t
import elemwise
import scalar as scal
from gof.python25 import partial
......@@ -91,14 +91,19 @@ class Tensor(Type):
def filter(self, data, strict = False):
_data = data
if strict:
assert isinstance(data, numpy.ndarray)
assert str(data.dtype) == self.dtype
if not isinstance(data, numpy.ndarray):
raise TypeError("%s expected a ndarray object.", data, type(data))
if not str(data.dtype) == self.dtype:
raise TypeError("%s expected a ndarray object with dtype = %s (got %s)." % (self, self.dtype, data.dtype))
if not data.ndim == self.ndim:
raise TypeError("%s expected a ndarray object with %s dimensions (got %s)." % (self, self.ndim, data.ndim))
return data
else:
data = numpy.asarray(data, dtype = self.dtype)
if not self.ndim == data.ndim:
raise TypeError("Wrong number of dimensions: expected %s, got %s." % (self.ndim, data.ndim), _data)
if any(b and d != 1 for d, b in zip(data.shape, self.broadcastable)):
raise ValueError("Non-unit value on shape on a broadcastable dimension.", data.shape, self.broadcastable)
raise TypeError("Non-unit value on shape on a broadcastable dimension.", data.shape, self.broadcastable)
return data
def dtype_specs(self):
......@@ -392,11 +397,11 @@ class TensorConstant(Constant, _tensor_py_operators):
class TensorValue(Value, _tensor_py_operators):
pass
s2t.as_tensor = as_tensor
s2t.Tensor = Tensor
s2t.TensorResult = TensorResult
s2t.TensorConstant = TensorConstant
s2t.TensorValue = TensorValue
elemwise.as_tensor = as_tensor
elemwise.Tensor = Tensor
elemwise.TensorResult = TensorResult
elemwise.TensorConstant = TensorConstant
elemwise.TensorValue = TensorValue
......@@ -405,9 +410,9 @@ s2t.TensorValue = TensorValue
#########################
def _elemwise(scalar_op, name):
straight = s2t.Elemwise(scalar_op)
straight = elemwise.Elemwise(scalar_op, name = name)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = s2t.Elemwise(inplace_scalar_op, {0: 0}, name = name)
inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = name+"_inplace")
return straight, inplace
......@@ -454,14 +459,14 @@ def cast(t, dtype):
'complex128': convert_to_complex128}
return mapping[dtype](t)
convert_to_int8 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int8)))
convert_to_int16 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int16)))
convert_to_int32 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int32)))
convert_to_int64 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.int64)))
convert_to_float32 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.float32)))
convert_to_float64 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.float64)))
convert_to_complex64 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.complex64)))
convert_to_complex128 = s2t.Elemwise(scal.Identity(scal.specific_out(scal.complex128)))
convert_to_int8 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int8)))
convert_to_int16 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int16)))
convert_to_int32 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int32)))
convert_to_int64 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int64)))
convert_to_float32 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.float32)))
convert_to_float64 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.float64)))
convert_to_complex64 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.complex64)))
convert_to_complex128 = elemwise.Elemwise(scal.Identity(scal.specific_out(scal.complex128)))
......@@ -580,10 +585,10 @@ def ones_like(model):
def zeros_like(model):
return fill(model, 0.0)
tensor_copy = s2t.Elemwise(scal.identity)
tensor_copy = elemwise.Elemwise(scal.identity)
def sum(input, axis = None):
return s2t.Sum(axis)(input)
return elemwise.Sum(axis)(input)
##########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论