提交 e49d3173 authored 作者: James Bergstra's avatar James Bergstra

added is_valid_value to the Type interface

上级 11c7bab8
......@@ -52,6 +52,14 @@ class Scalar(Type):
except Exception, e:
raise TypeError("Could not convert %s (value=%s) to %s" % (type(data), data, self.dtype), e)
def values_eq_enough(self, a, b):
return abs(a - b) / (a+b) < 1e-4
def is_valid_value(self, a):
_a = numpy.asarray(a)
rval = (_a.ndim == 0) and (str(_a.dtype) == self.dtype)
return rval
def __eq__(self, other):
return type(self) == type(other) and other.dtype == self.dtype
......
......@@ -9,6 +9,7 @@ To read about different sparse formats, see U{http://www-users.cs.umn.edu/~saad/
import sys, operator
import numpy
from scipy import sparse
import scipy.sparse
from .. import gof
from .. import tensor
......@@ -185,6 +186,14 @@ class Sparse(gof.Type):
def __repr__(self):
return "Sparse[%s, %s]" % (str(self.dtype), str(self.format))
def values_eq_enough(self, a, b, eps=1e-6):
return scipy.sparse.issparse(a) \
and scipy.sparse.issparse(b) \
and abs(a-b).sum() < (1e-6 * a.nnz)
def is_valid_value(self, a):
return scipy.sparse.issparse(a) and (a.format == self.format)
csc_matrix = Sparse(format='csc')
csr_matrix = Sparse(format='csr')
......
......@@ -226,6 +226,17 @@ class Tensor(Type):
return type(a) is numpy.ndarray and type(b) is numpy.ndarray \
and (a.shape == b.shape) and numpy.allclose(a, b)
def is_valid_value(self, a):
rval = (type(a) is numpy.ndarray) and (self.ndim == a.ndim) \
and (str(a.dtype) == self.dtype) \
and all([((si == 1) or not bi) for si, bi in zip(a.shape, self.broadcastable)])
if not rval:
print type(a),(type(a) is numpy.ndarray)
print a.ndim, (self.ndim == a.ndim)
print a.dtype, (str(a.dtype) == self.dtype)
print a.shape, self.broadcastable, ([(shp_i == 1) for shp_i in a.shape] == self.broadcastable)
return rval
def __hash__(self):
"""Hash equal for same kinds of Tensor"""
return hash(self.dtype) ^ hash(self.broadcastable)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论