broken

上级 460f6b78
import time import time, unittest
import numpy
import gof import gof
import gof.lib import gof.lib
...@@ -170,13 +171,13 @@ class prog(gof.Prog): ...@@ -170,13 +171,13 @@ class prog(gof.Prog):
""" """
if check_uncomputed: if check_uncomputed:
for input in self.env.inputs: for input in self.env.inputs:
if input.data is core.UNCOMPUTED: if input.data is None:
raise Exception("You must provide a value for input %s!" % input) raise Exception("You must provide a value for input %s!" % input)
return gof.Prog.__call__(self) return gof.Prog.__call__(self)
def compute_orphans(self): def compute_orphans(self):
for orphan in self.env.orphans(): for orphan in self.env.orphans():
if orphan.data is core.UNCOMPUTED: if orphan.data is None:
if orphan.owner: if orphan.owner:
gof.lib.compute(orphan.owner) gof.lib.compute(orphan.owner)
else: else:
...@@ -200,3 +201,22 @@ def to_func(inputs, outputs): ...@@ -200,3 +201,22 @@ def to_func(inputs, outputs):
def single(*outputs, **kwargs): def single(*outputs, **kwargs):
return prog(gof.graph.inputs(outputs), outputs, **kwargs) return prog(gof.graph.inputs(outputs), outputs, **kwargs)
class _test_single(unittest.TestCase):
def setUp(self):
core.build_eval_mode()
numpy.random.seed(44)
def tearDown(self):
core.pop_mode()
def test_3(self):
a = core.Numpy2(data=numpy.ones((2,2)))
b = core.Numpy2(data=numpy.ones((2,2)))
c = core.add(a,b)
p = single(c)
p()
self.failUnless(core._approx_eq(c, numpy.ones((2,2))*2))
if __name__ == '__main__':
unittest.main()
差异被折叠。
...@@ -5,7 +5,7 @@ import graph ...@@ -5,7 +5,7 @@ import graph
from utils import ClsInit from utils import ClsInit
from err import GofError, GofTypeError, PropagationError from err import GofError, GofTypeError, PropagationError
from op import Op from op import Op
from result import Result from result import is_result
from features import Listener, Orderings, Constraint, Tool from features import Listener, Orderings, Constraint, Tool
import utils import utils
...@@ -14,10 +14,10 @@ __all__ = ['InconsistencyError', ...@@ -14,10 +14,10 @@ __all__ = ['InconsistencyError',
# class AliasDict(dict): # class AliasDict(dict):
# "Utility class to keep track of what Result has been replaced with what Result." # "Utility class to keep track of what result has been replaced with what result."
# def group(self, main, *keys): # def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main." # "Marks all the keys as having been replaced by the result main."
# keys = [key for key in keys if key is not main] # keys = [key for key in keys if key is not main]
# if self.has_key(main): # if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.") # raise Exception("Only group results that have not been grouped before.")
...@@ -284,9 +284,11 @@ class Env(graph.Graph): ...@@ -284,9 +284,11 @@ class Env(graph.Graph):
be raised if there are type mismatches. be raised if there are type mismatches.
""" """
# Assert that they are Result instances. # Assert that they are result instances.
assert isinstance(r, Result) if not is_result(r):
assert isinstance(new_r, Result) raise TypeError(r)
if not is_result(new_r):
raise TypeError(new_r)
# Save where we are so we can backtrack # Save where we are so we can backtrack
if consistency_check: if consistency_check:
......
from op import Op from op import Op
from result import Result #, HolderResult from result import Result, is_result
from utils import ClsInit, Keyword, AbstractFunctionError from utils import ClsInit, Keyword, AbstractFunctionError
import opt import opt
import env import env
...@@ -8,15 +8,14 @@ import features ...@@ -8,15 +8,14 @@ import features
import ext import ext
__all__ = ['UNCOMPUTED', __all__ = [ 'UNDEFINED',
'UNDEFINED',
'current_mode', 'current_mode',
'set_mode', 'set_mode',
'build_mode', 'build_mode',
'eval_mode', 'eval_mode',
'build_eval_mode', 'build_eval_mode',
'pop_mode', 'pop_mode',
'ResultValue', #'ResultValue',
'DummyOp', 'DummyOp',
'DummyRemover', 'DummyRemover',
'PythonOp', 'PythonOp',
...@@ -24,7 +23,6 @@ __all__ = ['UNCOMPUTED', ...@@ -24,7 +23,6 @@ __all__ = ['UNCOMPUTED',
'make_static'] 'make_static']
UNCOMPUTED = Keyword("UNCOMPUTED", False)
UNDEFINED = Keyword("UNDEFINED", False) UNDEFINED = Keyword("UNDEFINED", False)
def make_static(cls, fname): def make_static(cls, fname):
...@@ -58,11 +56,6 @@ def compute(*nodes): ...@@ -58,11 +56,6 @@ def compute(*nodes):
"""Recursively evaluate each node (in a quick & dirty way).""" """Recursively evaluate each node (in a quick & dirty way)."""
compute_from(nodes, set()) compute_from(nodes, set())
def is_result(obj):
"""Return True iff obj provides the interface of a Result"""
attr_list = 'data', 'owner'
return all([hasattr(obj, attr) for attr in attr_list])
class ForbidConstantOverwrite(features.Listener, features.Constraint): class ForbidConstantOverwrite(features.Listener, features.Constraint):
def __init__(self, env): def __init__(self, env):
...@@ -105,7 +98,8 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint): ...@@ -105,7 +98,8 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
class ResultValue(Result): if 0:
class ResultValue(Result):
"""Augment Result to wrap a computed value. """Augment Result to wrap a computed value.
Attributes: Attributes:
...@@ -128,7 +122,7 @@ class ResultValue(Result): ...@@ -128,7 +122,7 @@ class ResultValue(Result):
__slots__ = ['data', 'spec', 'constant', 'up_to_date'] __slots__ = ['data', 'spec', 'constant', 'up_to_date']
def __init__(self, x = UNCOMPUTED, constant = False): def __init__(self, x = None, constant = False):
self.constant = False self.constant = False
self.set_value(x) # allow set_value before constant = True self.set_value(x) # allow set_value before constant = True
self.constant = constant self.constant = constant
...@@ -152,7 +146,7 @@ class ResultValue(Result): ...@@ -152,7 +146,7 @@ class ResultValue(Result):
raise Exception("This Result is a constant. Its value cannot be changed.") raise Exception("This Result is a constant. Its value cannot be changed.")
if value is None or value is UNCOMPUTED: if value is None or value is UNCOMPUTED:
self.data = UNCOMPUTED self.data = UNCOMPUTED
elif isinstance(value, ResultValue): elif is_result(value):
self.set_value(value.data) self.set_value(value.data)
else: else:
try: try:
...@@ -467,7 +461,8 @@ class PythonOp(Op): ...@@ -467,7 +461,8 @@ class PythonOp(Op):
return all([ is_result(i) for i in self.inputs]) return all([ is_result(i) for i in self.inputs])
def gen_outputs(self): def gen_outputs(self):
return [ResultValue() for i in xrange(self.nout)] raise NotImplementedError()
#return [ResultValue() for i in xrange(self.nout)]
def view_map(self): return {} def view_map(self): return {}
...@@ -500,7 +495,7 @@ class PythonOp(Op): ...@@ -500,7 +495,7 @@ class PythonOp(Op):
return answer return answer
def check_input(self, input): def check_input(self, input):
if input.data is UNCOMPUTED: if input.data is None:
raise ValueError("Uncomputed input: %s in %s" % (input, self)) raise ValueError("Uncomputed input: %s in %s" % (input, self))
if not self.input_is_up_to_date(input): if not self.input_is_up_to_date(input):
raise ValueError("Input is out of date: %s in %s" % (input, self)) raise ValueError("Input is out of date: %s in %s" % (input, self))
......
...@@ -10,7 +10,7 @@ from err import GofError ...@@ -10,7 +10,7 @@ from err import GofError
from utils import AbstractFunctionError from utils import AbstractFunctionError
__all__ = ['Result', 'BrokenLink', 'BrokenLinkError'] __all__ = ['is_result', 'Result', 'BrokenLink', 'BrokenLinkError']
class BrokenLink: class BrokenLink:
...@@ -40,6 +40,11 @@ class BrokenLinkError(GofError): ...@@ -40,6 +40,11 @@ class BrokenLinkError(GofError):
# Result # Result
############################ ############################
def is_result(obj):
"""Return True iff obj provides the interface of a Result"""
attr_list = 'owner',
return all([hasattr(obj, attr) for attr in attr_list])
class Result(object): class Result(object):
"""Storage node for data in a graph of Op instances. """Storage node for data in a graph of Op instances.
......
...@@ -59,23 +59,21 @@ class Grad(object): ...@@ -59,23 +59,21 @@ class Grad(object):
""" """
if dr is core.UNDEFINED: if dr is core.UNDEFINED:
# nothing to do # nothing to do
pass return
else:
if r.data is core.UNCOMPUTED or dr.data is core.UNCOMPUTED: if r.data is not None and dr.data is not None:
pass # no sanity checking if not hasattr(r, 'shape'):
else: # some sanity checking to catch obvious mistakes
if not hasattr(r.data, 'shape'):
raise ValueError(('Grad::add r lacks shape: type=', raise ValueError(('Grad::add r lacks shape: type=',
type(r.data))) type(r)))
if not hasattr(dr.data, 'shape'): if not hasattr(dr, 'shape'):
raise ValueError(('Grad::add dr lacks shape: type=', raise ValueError(('Grad::add dr lacks shape: type=',
type(dr.data))) type(dr)))
if r.data.shape != dr.data.shape: if r.shape != dr.shape:
raise ValueError(('Grad::add r, dr shape mismatch', raise ValueError(('Grad::add r, dr shape mismatch',
r.data.shape, dr.data.shape)) r.shape, dr.shape))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode # prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if r.data is not core.UNCOMPUTED: if r.computed:
self._compute_history.add(r) self._compute_history.add(r)
# add dr to self[r] # add dr to self[r]
...@@ -213,14 +211,14 @@ class _testCase (unittest.TestCase): ...@@ -213,14 +211,14 @@ class _testCase (unittest.TestCase):
diff = wwi - ident diff = wwi - ident
ssdiff = core.sum((diff**2)) ssdiff = core.sum((diff**2))
if i == 0: if i == 0:
str0 = str_ssdiff = str(ssdiff) str0 = str_ssdiff = str(ssdiff.data)
#print ssdiff #print ssdiff
g = grad(ssdiff) g = grad(ssdiff)
gw = g(w) gw = g(w)
w.data += -0.4 * gw.data w.data[:] += -0.4 * gw.data
return str0, str(ssdiff) return str0, str(ssdiff.data)
def matinv_compiled(self, dim): def matinv_compiled(self, dim):
w = core.wrap(numpy.random.rand(dim,dim)) w = core.wrap(numpy.random.rand(dim,dim))
...@@ -230,7 +228,7 @@ class _testCase (unittest.TestCase): ...@@ -230,7 +228,7 @@ class _testCase (unittest.TestCase):
wwi = core.dot(w, wi) wwi = core.dot(w, wi)
diff = wwi - ident diff = wwi - ident
ssdiff = core.sum((diff**2)) ssdiff = core.sum((diff**2))
str0 = str_ssdiff = str(ssdiff) str0 = str_ssdiff = str(ssdiff.data)
#print ssdiff #print ssdiff
g = grad(ssdiff) g = grad(ssdiff)
...@@ -240,9 +238,9 @@ class _testCase (unittest.TestCase): ...@@ -240,9 +238,9 @@ class _testCase (unittest.TestCase):
for i in xrange(300): for i in xrange(300):
prog() prog()
w.data += -0.4 * gw.data w.data[:] += -0.4 * gw.data
return str0, str(ssdiff) return str0, str(ssdiff.data)
def test0(self): def test0(self):
"""Matrix inversion by gradient descent (eval mode)""" """Matrix inversion by gradient descent (eval mode)"""
......
...@@ -9,7 +9,7 @@ import grad ...@@ -9,7 +9,7 @@ import grad
# Wrapper type # Wrapper type
class SparseR(gof.ResultValue): class SparseR(core.ResultBase):
""" """
Attribute: Attribute:
format - a subclass of sparse.spmatrix indicating self.data.__class__ format - a subclass of sparse.spmatrix indicating self.data.__class__
...@@ -22,12 +22,17 @@ class SparseR(gof.ResultValue): ...@@ -22,12 +22,17 @@ class SparseR(gof.ResultValue):
Notes: Notes:
""" """
def __init__(self, x = core.UNCOMPUTED, constant = False, def __init__(self, data=None, role=None, constant = False,
format = sparse.csr_matrix): format = sparse.csr_matrix):
gof.ResultValue.__init__(self, x, constant) core.ResultBase.__init__(self, role, data, constant)
self.format = isinstance(x, sparse.spmatrix) and x.__class__ or format if isinstance(data, sparse.spmatrix):
self.format = data.__class__
else:
self.format = format
self._dtype = None
self._shape = None
def set_value_filter(self, value): def data_filter(self, value):
if isinstance(value, sparse.spmatrix): return value if isinstance(value, sparse.spmatrix): return value
return sparse.csr_matrix(value) return sparse.csr_matrix(value)
...@@ -36,6 +41,32 @@ class SparseR(gof.ResultValue): ...@@ -36,6 +41,32 @@ class SparseR(gof.ResultValue):
T = property(lambda self: transpose(self), doc = "Return aliased transpose") T = property(lambda self: transpose(self), doc = "Return aliased transpose")
# self._dtype is used when self._data hasn't been set yet
def __dtype_get(self):
if self._data is None:
return self._dtype
else:
return self._data.dtype
def __dtype_set(self, dtype):
if self._data is None:
self._dtype = dtype
else:
raise StateError('cannot set dtype after data has been set')
dtype = property(__dtype_get, __dtype_set)
# self._shape is used when self._data hasn't been set yet
def __shape_get(self):
if self._data is None:
return self._shape
else:
return self._data.shape
def __shape_set(self, shape):
if self._data is None:
self._shape = shape
else:
raise StateError('cannot set shape after data has been set')
shape = property(__shape_get, __shape_set)
# convenience base class # convenience base class
class op(gof.PythonOp, grad.update_gradient_via_grad): class op(gof.PythonOp, grad.update_gradient_via_grad):
pass pass
...@@ -46,7 +77,7 @@ class op(gof.PythonOp, grad.update_gradient_via_grad): ...@@ -46,7 +77,7 @@ class op(gof.PythonOp, grad.update_gradient_via_grad):
# convert a sparse matrix to an ndarray # convert a sparse matrix to an ndarray
class sparse2dense(op): class sparse2dense(op):
def gen_outputs(self): return [core.NumpyR()] def gen_outputs(self): return [core.Numpy2()]
def impl(x): return numpy.asarray(x.todense()) def impl(x): return numpy.asarray(x.todense())
def grad(self, x, gz): def grad(self, x, gz):
if x.format is sparse.coo_matrix: return dense2coo(gz) if x.format is sparse.coo_matrix: return dense2coo(gz)
...@@ -135,7 +166,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -135,7 +166,7 @@ class _testCase_dot(unittest.TestCase):
def test_basic0(self): def test_basic0(self):
for mtype in [sparse.csc_matrix, sparse.csr_matrix]: for mtype in [sparse.csc_matrix, sparse.csr_matrix]:
x = SparseR(mtype(sparse.speye(5,3))) x = SparseR(mtype(sparse.speye(5,3)))
y = core.NumpyR(numpy.random.rand(3, 2)) y = core.wrap(numpy.random.rand(3, 2))
z = dot(x,y) z = dot(x,y)
self.failUnless(z.data.shape == (5,2)) self.failUnless(z.data.shape == (5,2))
...@@ -171,7 +202,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -171,7 +202,7 @@ class _testCase_dot(unittest.TestCase):
self.failUnless(z.data.shape == ab.shape) self.failUnless(z.data.shape == ab.shape)
self.failUnless(type(z.data) == type(ab)) self.failUnless(type(z.data) == type(ab))
def test_graph_bprop0(self): def test_graph_bprop0(self):
x = core.NumpyR(numpy.random.rand(10,2)) x = core.wrap(numpy.random.rand(10,2))
w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0, w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0,
0]],dtype='float64'))) 0]],dtype='float64')))
...@@ -187,10 +218,10 @@ class _testCase_dot(unittest.TestCase): ...@@ -187,10 +218,10 @@ class _testCase_dot(unittest.TestCase):
g(w).data[1,4] = 0 g(w).data[1,4] = 0
w.data = -lr * g(w).data + w.data w.data = -lr * g(w).data + w.data
self.failUnless('3.08560636025' == str(loss)) self.failUnless('3.08560636025' == str(loss.data))
def test_graph_bprop1(self): def test_graph_bprop1(self):
x = core.NumpyR(numpy.random.rand(10,2)) x = core.wrap(numpy.random.rand(10,2))
w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0, w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0,
0]],dtype='float64'))) 0]],dtype='float64')))
...@@ -205,7 +236,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -205,7 +236,7 @@ class _testCase_dot(unittest.TestCase):
g(w).data[1,4] = 0 g(w).data[1,4] = 0
w.data = -lr * g(w).data + w.data w.data = -lr * g(w).data + w.data
self.failUnless('3.08560636025' == str(loss)) self.failUnless('3.08560636025' == str(loss.data))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论