bugs fixed toward replacing NumpyR

上级 45f56231
...@@ -39,16 +39,18 @@ def as_string(*rs): ...@@ -39,16 +39,18 @@ def as_string(*rs):
def print_graph(*rs): def print_graph(*rs):
print as_string(*rs) print as_string(*rs)
#useful mostly for unit tests
def _approx_eq(a,b,eps=1.0e-9): def _approx_eq(a,b,eps=1.0e-9):
a = numpy.asarray(a) a = numpy.asarray(a)
b = numpy.asarray(b) b = numpy.asarray(b)
if a.shape != b.shape: if a.shape != b.shape:
return False return False
d = abs(a-b) return numpy.max( numpy.abs(a-b)) < eps
return numpy.all(d < eps)
@blas._constant # TODO: move this decorator to a utility script # This function is only executed the first time it is called, subsequent calls
# return immediately from a cache of the first return value
@blas._constant # TODO: move this decorator to a utility file
def _compile_dir(): def _compile_dir():
"""Return the directory in which scipy.weave should store code objects. """Return the directory in which scipy.weave should store code objects.
...@@ -82,73 +84,256 @@ def _compile_dir(): ...@@ -82,73 +84,256 @@ def _compile_dir():
sys.path.append(cachedir) sys.path.append(cachedir)
return cachedir return cachedir
class ResultBase:
"""Base class for storing Op inputs and outputs
Attributes:
_role - None or (owner, index) or BrokenLink
_data - anything
constant - Boolean
Properties:
role - (rw)
owner - (ro)
index - (ro)
data - (rw)
replaced - (rw) : True iff _role is BrokenLink
Abstract Methods:
data_filter
"""
class BrokenLink:
"""The owner of a Result that was replaced by another Result"""
__slots__ = ['old_role']
def __init__(self, role): self.old_role = role
def __nonzero__(self): return False
class BrokenLinkError(Exception):
"""Exception thrown when an owner is a BrokenLink"""
class AbstractFunction(Exception):
"""Exception thrown when an abstract function is called"""
__slots__ = ['_role', '_data', 'constant']
def __init__(self, role=None, data=None, constant=False):
self._role = role
self.constant = constant
if data is None: #None is not filtered
self._data = None
else:
try:
self._data = self.data_filter(data)
except ResultBase.AbstractFunction:
self._data = data
#role is pair: (owner, outputs_position)
def __get_role(self):
return self._role
def __set_role(self, role):
owner, index = role
if self._role is not None:
# this is either an error or a no-op
_owner, _index = self._role
if _owner is not owner:
raise ValueError("Result %s already has an owner." % self)
if _index != index:
raise ValueError("Result %s was already mapped to a different index." % self)
return # because _owner is owner and _index == index
self._role = role
role = property(__get_role, __set_role)
#owner is role[0]
def __get_owner(self):
if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError()
return self._role[0]
owner = property(__get_owner,
doc = "Op of which this Result is an output, or None if role is None")
#index is role[1]
def __get_index(self):
if self._role is None: return None
if self.replaced: raise ResultBase.BrokenLinkError()
return self._role[1]
index = property(__get_index,
doc = "position of self in owner's outputs, or None if role is None")
# assigning to self.data will invoke self.data_filter(value) if that
# function is defined
def __get_data(self):
return self._data
def __set_data(self, data):
if self.replaced: raise ResultBase.BrokenLinkError()
if self.constant: raise Exception('cannot set constant ResultBase')
try:
self._data = self.data_filter(self, data)
except ResultBase.AbstractFunction:
self._data = data
data = property(__get_data, __set_data,
doc = "The storage associated with this result")
def data_filter(self, data):
"""(abstract) Return an appropriate _data based on data."""
raise ResultBase.AbstractFunction()
# replaced
def __get_replaced(self): return isinstance(self._role, ResultBase.BrokenLink)
def __set_replaced(self, replace):
if replace == self.replaced: return
if replace:
self._role = ResultBase.BrokenLink(self._role)
else:
self._role = self._role.old_role
replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
class _test_ResultBase(unittest.TestCase):
def test_0(self):
r = ResultBase()
class Numpy2(ResultBase):
"""Result storing a numpy ndarray"""
__slots__ = ['_dtype', '_shape', ]
class ShapeUnknown: pass # TODO: use this as the shape of uncomputed ndarrays of unknown shape
class StateError(Exception): pass
def __init__(self, role=None, data=None, constant=False):
if isinstance(data, (tuple, list)): # unallocated setup
shape, dtype = data
ResultBase.__init__(self, role, data=None, constant=constant)
self._shape = shape
self._dtype = dtype
else: # allocated setup
ResultBase.__init__(self, role, data, constant)
################################
# ResultBase
#
def data_filter(self, data):
#return numpy.asarray(data) #TODO: consider whether this is correct
if isinstance(data, numpy.ndarray):
return data
raise TypeError('failed to filter data to ndarray', data)
################################
# Numpy2 specific functionality
#
__array__ = property(lambda self: self._data.__array__ )
__array_struct__ = property(lambda self: self._data.__array_struct__ )
def data_set_inplace(self, data):
raise NotImplementedError()
def data_alloc(self):
self.data = numpy.ndarray(self.shape, self.dtype)
# self._dtype is used when self._data hasn't been set yet
def __dtype_get(self):
if self._data is None:
return self._dtype
else:
return self._data.dtype
def __dtype_set(self, dtype):
if self._data is None:
self._dtype = dtype
else:
raise StateError('cannot set dtype after data has been set')
dtype = property(__dtype_get, __dtype_set)
# self._shape is used when self._data hasn't been set yet
def __shape_get(self):
if self._data is None:
return self._shape
else:
return self._data.shape
def __shape_set(self, shape):
if self._data is None:
self._shape = shape
else:
raise StateError('cannot set shape after data has been set')
shape = property(__shape_get, __shape_set)
class _test_Numpy2(unittest.TestCase):
def test_0(self):
r = Numpy2()
def test_1(self):
o = numpy.ones((3,3))
r = Numpy2(data=o)
self.failUnless(r.data is o)
self.failUnless(r.shape == (3,3))
self.failUnless(str(r.dtype) == 'float64')
def test_2(self):
r = Numpy2(data=[(3,3),'int32'])
self.failUnless(r.data is None)
self.failUnless(r.shape == (3,3))
self.failUnless(str(r.dtype) == 'int32')
r.data_alloc()
self.failUnless(isinstance(r.data, numpy.ndarray))
self.failUnless(r.shape == (3,3))
self.failUnless(str(r.dtype) == 'int32')
def input(x): def input(x):
#NB: #static member initialization
# - automatically casting int to float seems wrong. if not hasattr(input, 'float_dtype'):
# - we want to be able to write y = x + 1 and maybe have the 1 casted to 1.0 input.float_dtype = 'float64'
# at some point to maximize speed right? input.int_dtype = 'int64'
# - But more important is the ability to store index values without them input.NN = NumpyR
# being cast to floating-point (can that cause incorrectness?)
if isinstance(x, numpy.ndarray): if isinstance(x, numpy.ndarray):
return NumpyR(x) return input.NN(x)
elif isinstance(x, int): elif isinstance(x, int):
z = numpy.zeros((), dtype = input.int_dtype) z = numpy.zeros((), dtype = input.int_dtype)
z += x z += x
return NumpyR(z) return input.NN(z)
elif isinstance(x, float): elif isinstance(x, float):
z = numpy.zeros((), dtype = input.float_dtype) z = numpy.zeros((), dtype = input.float_dtype)
z += x z += x
return NumpyR(z) return input.NN(z)
elif isinstance(x, gof.Result): elif isinstance(x, gof.Result):
raise TypeError("%s is already a result." % x) raise TypeError("%s is already a result." % x)
else: else:
return ResultValue(x) return ResultValue(x)
class _testCase_input(unittest.TestCase):
input.float_dtype = 'float64' def setUp(self):
input.int_dtype = 'int64' literal.hdb = {}
literal.udb = {}
def test_input_int(self):
w = input(3)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
def test_input_float(self):
w = input(3.0)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
def wrap(x): def wrap(x):
if isinstance(x, NumpyR): if isinstance(x, NumpyR):
return x return x
elif isinstance(x, Numpy2):
return x
elif isinstance(x, ResultValue): elif isinstance(x, ResultValue):
return x return x
elif isinstance(x, omega_op): elif isinstance(x, omega_op):
return x.out return x.out
else: else:
return literal(x) return literal(x)
class _testCase_wrap(unittest.TestCase): class _testCase_wrap(unittest.TestCase):
def setUp(self): def setUp(self):
literal.hdb = {} literal.hdb = {}
literal.udb = {} literal.udb = {}
def test_input_int(self):
w = input(3)
self.failUnless(isinstance(w, NumpyR))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
def test_input_float(self):
w = input(3.0)
self.failUnless(isinstance(w, NumpyR))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
def test_literal_int(self):
w = literal(3)
self.failUnless(isinstance(w, NumpyR))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
def test_literal_float(self):
w = literal(3.0)
self.failUnless(isinstance(w, NumpyR))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
def test_wrap_int(self): def test_wrap_int(self):
w = wrap(3) w = wrap(3)
self.failUnless(isinstance(w, NumpyR)) self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.int_dtype) self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3) self.failUnless(w.data == 3)
def test_wrap_float(self): def test_wrap_float(self):
w = wrap(3.0) w = wrap(3.0)
self.failUnless(isinstance(w, NumpyR)) self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.float_dtype) self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0) self.failUnless(w.data == 3.0)
...@@ -168,7 +353,7 @@ def literal(x): ...@@ -168,7 +353,7 @@ def literal(x):
if _hashable(x): if _hashable(x):
db = literal.hdb db = literal.hdb
key = (id(x),x) key = (type(x),x)
else: else:
db = literal.udb db = literal.udb
key = (id(x),) key = (id(x),)
...@@ -180,6 +365,33 @@ def literal(x): ...@@ -180,6 +365,33 @@ def literal(x):
rval.constant = True rval.constant = True
db[key] = rval db[key] = rval
return rval return rval
class _testCase_literal(unittest.TestCase):
def setUp(self):
literal.hdb = {}
literal.udb = {}
def test_int(self):
w = literal(3)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
u = literal(1+2)
self.failUnless(u is w)
def test_float(self):
w = literal(3.0)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
u = literal(1.0+2.0)
self.failUnless(u is w)
def test_mixed(self):
f = literal(2.0)
i = literal(2)
self.failUnless(i is not f)
...@@ -756,7 +968,10 @@ class NumpyR(gof.ResultValue): ...@@ -756,7 +968,10 @@ class NumpyR(gof.ResultValue):
def __getitem__(self, item): return get_slice(self, item) def __getitem__(self, item): return get_slice(self, item)
def __getslice__(self, *args): return get_slice(self, slice(*args)) def __getslice__(self, *args): return get_slice(self, slice(*args))
def wrap_producer(f): def wrap_producer(f):
class producer(omega_op): class producer(omega_op):
impl = f impl = f
...@@ -1075,7 +1290,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -1075,7 +1290,7 @@ class _testCase_dot(unittest.TestCase):
try: try:
z = dot(x,y) z = dot(x,y)
except ValueError, e: except ValueError, e:
self.failUnless(str(e) == 'matrices are not aligned') self.failUnless(str(e) == 'objects are not aligned', e)
return return
self.fail() self.fail()
...@@ -1085,7 +1300,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -1085,7 +1300,7 @@ class _testCase_dot(unittest.TestCase):
try: try:
z = dot(x,y) z = dot(x,y)
except ValueError, e: except ValueError, e:
self.failUnless(str(e) == 'matrices are not aligned') self.failUnless(str(e) == 'objects are not aligned', e)
return return
self.fail() self.fail()
def test_dot_fail_1_3(self): def test_dot_fail_1_3(self):
...@@ -1094,7 +1309,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -1094,7 +1309,7 @@ class _testCase_dot(unittest.TestCase):
try: try:
z = dot(x,y) z = dot(x,y)
except ValueError, e: except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned') self.failUnless(str(e) == 'objects are not aligned', e)
return return
self.fail() self.fail()
def test_dot_fail_2_1(self): def test_dot_fail_2_1(self):
...@@ -1103,7 +1318,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -1103,7 +1318,7 @@ class _testCase_dot(unittest.TestCase):
try: try:
z = dot(x,y) z = dot(x,y)
except ValueError, e: except ValueError, e:
self.failUnless(str(e) == 'matrices are not aligned') self.failUnless(str(e) == 'objects are not aligned', e)
return return
self.fail() self.fail()
def test_dot_fail_2_2(self): def test_dot_fail_2_2(self):
...@@ -1112,7 +1327,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -1112,7 +1327,7 @@ class _testCase_dot(unittest.TestCase):
try: try:
z = dot(x,y) z = dot(x,y)
except ValueError, e: except ValueError, e:
self.failUnless(str(e) == 'matrices are not aligned') self.failUnless(str(e) == 'objects are not aligned', e)
return return
self.fail() self.fail()
def test_dot_fail_2_3(self): def test_dot_fail_2_3(self):
...@@ -1121,7 +1336,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -1121,7 +1336,7 @@ class _testCase_dot(unittest.TestCase):
try: try:
z = dot(x,y) z = dot(x,y)
except ValueError, e: except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned') self.failUnless(str(e) == 'objects are not aligned', e)
return return
self.fail() self.fail()
def test_dot_fail_3_1(self): def test_dot_fail_3_1(self):
...@@ -1130,7 +1345,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -1130,7 +1345,7 @@ class _testCase_dot(unittest.TestCase):
try: try:
z = dot(x,y) z = dot(x,y)
except ValueError, e: except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned') self.failUnless(str(e) == 'objects are not aligned', e)
return return
self.fail() self.fail()
def test_dot_fail_3_2(self): def test_dot_fail_3_2(self):
...@@ -1139,7 +1354,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -1139,7 +1354,7 @@ class _testCase_dot(unittest.TestCase):
try: try:
z = dot(x,y) z = dot(x,y)
except ValueError, e: except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned') self.failUnless(str(e) == 'objects are not aligned', e)
return return
self.fail() self.fail()
def test_dot_fail_3_3(self): def test_dot_fail_3_3(self):
...@@ -1148,7 +1363,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -1148,7 +1363,7 @@ class _testCase_dot(unittest.TestCase):
try: try:
z = dot(x,y) z = dot(x,y)
except ValueError, e: except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned') self.failUnless(str(e) == 'objects are not aligned', e)
return return
self.fail() self.fail()
...@@ -1425,15 +1640,19 @@ class _testCase_power(unittest.TestCase): ...@@ -1425,15 +1640,19 @@ class _testCase_power(unittest.TestCase):
numpy.random.seed(44) numpy.random.seed(44)
def tearDown(self): def tearDown(self):
pop_mode() pop_mode()
def test_0(self): def test_0(self):
r = numpy.random.rand(50) r = numpy.random.rand(50)
er = exp(r)
ler = log(er)
a,b = numpy.max(ler-r), numpy.min(ler-r) exp_r = exp(r)
self.failUnless(a < 1.0e-13 and b > -1.0e-13, 'exp and log are not inverses') self.failUnless( _approx_eq(exp_r, numpy.exp(r)))
log_exp_r = log(exp_r)
self.failUnless( _approx_eq(log_exp_r, r))
def test_1(self):
r = numpy.random.rand(50)
r2 = pow(r,2)
self.failUnless( _approx_eq(r2, r*r))
## Others ## ## Others ##
......
...@@ -58,6 +58,10 @@ def compute(*nodes): ...@@ -58,6 +58,10 @@ def compute(*nodes):
"""Recursively evaluate each node (in a quick & dirty way).""" """Recursively evaluate each node (in a quick & dirty way)."""
compute_from(nodes, set()) compute_from(nodes, set())
def is_result(obj):
"""Return True iff obj provides the interface of a Result"""
attr_list = 'data', 'owner'
return all([hasattr(obj, attr) for attr in attr_list])
class ForbidConstantOverwrite(features.Listener, features.Constraint): class ForbidConstantOverwrite(features.Listener, features.Constraint):
...@@ -460,8 +464,7 @@ class PythonOp(Op): ...@@ -460,8 +464,7 @@ class PythonOp(Op):
Op.__init__(self, inputs, self.gen_outputs()) Op.__init__(self, inputs, self.gen_outputs())
def __validate__(self): def __validate__(self):
for input in self.inputs: return all([ is_result(i) for i in self.inputs])
assert isinstance(input, ResultValue)
def gen_outputs(self): def gen_outputs(self):
return [ResultValue() for i in xrange(self.nout)] return [ResultValue() for i in xrange(self.nout)]
......
import gof import gof
from gof.lib import compute_from from gof.lib import compute_from, is_result
import core import core
class Grad(object): class Grad(object):
...@@ -173,7 +173,7 @@ class update_gradient_via_grad: ...@@ -173,7 +173,7 @@ class update_gradient_via_grad:
""" """
inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs])) inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
if len(self.inputs) == 1 and isinstance(inputgs, gof.ResultValue): if len(self.inputs) == 1 and is_result(inputgs):
inputgs = [inputgs] inputgs = [inputgs]
else: else:
assert len(inputgs) == len(self.inputs) assert len(inputgs) == len(self.inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论