提交 e49d3173 authored 作者: James Bergstra's avatar James Bergstra

added is_valid_value to the Type interface

上级 11c7bab8
...@@ -52,6 +52,14 @@ class Scalar(Type): ...@@ -52,6 +52,14 @@ class Scalar(Type):
except Exception, e: except Exception, e:
raise TypeError("Could not convert %s (value=%s) to %s" % (type(data), data, self.dtype), e) raise TypeError("Could not convert %s (value=%s) to %s" % (type(data), data, self.dtype), e)
def values_eq_enough(self, a, b):
return abs(a - b) / (a+b) < 1e-4
def is_valid_value(self, a):
_a = numpy.asarray(a)
rval = (_a.ndim == 0) and (str(_a.dtype) == self.dtype)
return rval
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and other.dtype == self.dtype return type(self) == type(other) and other.dtype == self.dtype
......
...@@ -9,6 +9,7 @@ To read about different sparse formats, see U{http://www-users.cs.umn.edu/~saad/ ...@@ -9,6 +9,7 @@ To read about different sparse formats, see U{http://www-users.cs.umn.edu/~saad/
import sys, operator import sys, operator
import numpy import numpy
from scipy import sparse from scipy import sparse
import scipy.sparse
from .. import gof from .. import gof
from .. import tensor from .. import tensor
...@@ -185,6 +186,14 @@ class Sparse(gof.Type): ...@@ -185,6 +186,14 @@ class Sparse(gof.Type):
def __repr__(self): def __repr__(self):
return "Sparse[%s, %s]" % (str(self.dtype), str(self.format)) return "Sparse[%s, %s]" % (str(self.dtype), str(self.format))
def values_eq_enough(self, a, b, eps=1e-6):
return scipy.sparse.issparse(a) \
and scipy.sparse.issparse(b) \
and abs(a-b).sum() < (1e-6 * a.nnz)
def is_valid_value(self, a):
return scipy.sparse.issparse(a) and (a.format == self.format)
csc_matrix = Sparse(format='csc') csc_matrix = Sparse(format='csc')
csr_matrix = Sparse(format='csr') csr_matrix = Sparse(format='csr')
......
...@@ -226,6 +226,17 @@ class Tensor(Type): ...@@ -226,6 +226,17 @@ class Tensor(Type):
return type(a) is numpy.ndarray and type(b) is numpy.ndarray \ return type(a) is numpy.ndarray and type(b) is numpy.ndarray \
and (a.shape == b.shape) and numpy.allclose(a, b) and (a.shape == b.shape) and numpy.allclose(a, b)
def is_valid_value(self, a):
rval = (type(a) is numpy.ndarray) and (self.ndim == a.ndim) \
and (str(a.dtype) == self.dtype) \
and all([((si == 1) or not bi) for si, bi in zip(a.shape, self.broadcastable)])
if not rval:
print type(a),(type(a) is numpy.ndarray)
print a.ndim, (self.ndim == a.ndim)
print a.dtype, (str(a.dtype) == self.dtype)
print a.shape, self.broadcastable, ([(shp_i == 1) for shp_i in a.shape] == self.broadcastable)
return rval
def __hash__(self): def __hash__(self):
"""Hash equal for same kinds of Tensor""" """Hash equal for same kinds of Tensor"""
return hash(self.dtype) ^ hash(self.broadcastable) return hash(self.dtype) ^ hash(self.broadcastable)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论