broken

上级 460f6b78
import time
import time, unittest
import numpy
import gof
import gof.lib
......@@ -170,13 +171,13 @@ class prog(gof.Prog):
"""
if check_uncomputed:
for input in self.env.inputs:
if input.data is core.UNCOMPUTED:
if input.data is None:
raise Exception("You must provide a value for input %s!" % input)
return gof.Prog.__call__(self)
def compute_orphans(self):
for orphan in self.env.orphans():
if orphan.data is core.UNCOMPUTED:
if orphan.data is None:
if orphan.owner:
gof.lib.compute(orphan.owner)
else:
......@@ -200,3 +201,22 @@ def to_func(inputs, outputs):
def single(*outputs, **kwargs):
return prog(gof.graph.inputs(outputs), outputs, **kwargs)
class _test_single(unittest.TestCase):
def setUp(self):
core.build_eval_mode()
numpy.random.seed(44)
def tearDown(self):
core.pop_mode()
def test_3(self):
a = core.Numpy2(data=numpy.ones((2,2)))
b = core.Numpy2(data=numpy.ones((2,2)))
c = core.add(a,b)
p = single(c)
p()
self.failUnless(core._approx_eq(c, numpy.ones((2,2))*2))
if __name__ == '__main__':
unittest.main()
差异被折叠。
......@@ -5,7 +5,7 @@ import graph
from utils import ClsInit
from err import GofError, GofTypeError, PropagationError
from op import Op
from result import Result
from result import is_result
from features import Listener, Orderings, Constraint, Tool
import utils
......@@ -14,10 +14,10 @@ __all__ = ['InconsistencyError',
# class AliasDict(dict):
# "Utility class to keep track of what Result has been replaced with what Result."
# "Utility class to keep track of what result has been replaced with what result."
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
# "Marks all the keys as having been replaced by the result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
......@@ -284,9 +284,11 @@ class Env(graph.Graph):
be raised if there are type mismatches.
"""
# Assert that they are Result instances.
assert isinstance(r, Result)
assert isinstance(new_r, Result)
# Assert that they are result instances.
if not is_result(r):
raise TypeError(r)
if not is_result(new_r):
raise TypeError(new_r)
# Save where we are so we can backtrack
if consistency_check:
......
from op import Op
from result import Result #, HolderResult
from result import Result, is_result
from utils import ClsInit, Keyword, AbstractFunctionError
import opt
import env
......@@ -8,15 +8,14 @@ import features
import ext
__all__ = ['UNCOMPUTED',
'UNDEFINED',
__all__ = [ 'UNDEFINED',
'current_mode',
'set_mode',
'build_mode',
'eval_mode',
'build_eval_mode',
'pop_mode',
'ResultValue',
#'ResultValue',
'DummyOp',
'DummyRemover',
'PythonOp',
......@@ -24,7 +23,6 @@ __all__ = ['UNCOMPUTED',
'make_static']
UNCOMPUTED = Keyword("UNCOMPUTED", False)
UNDEFINED = Keyword("UNDEFINED", False)
def make_static(cls, fname):
......@@ -58,11 +56,6 @@ def compute(*nodes):
"""Recursively evaluate each node (in a quick & dirty way)."""
compute_from(nodes, set())
def is_result(obj):
"""Return True iff obj provides the interface of a Result"""
attr_list = 'data', 'owner'
return all([hasattr(obj, attr) for attr in attr_list])
class ForbidConstantOverwrite(features.Listener, features.Constraint):
def __init__(self, env):
......@@ -105,80 +98,81 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
class ResultValue(Result):
"""Augment Result to wrap a computed value.
if 0:
class ResultValue(Result):
"""Augment Result to wrap a computed value.
Attributes:
data -
spec -
constant -
up_to_date -
Attributes:
data -
spec -
constant -
up_to_date -
Properties:
Properties:
Methods:
Methods:
set_value_filter - ABSTRACT
set_value_inplace - ABSTRACT
alloc - ABSTRACT
set_value_filter - ABSTRACT
set_value_inplace - ABSTRACT
alloc - ABSTRACT
Notes:
Notes:
"""
"""
__slots__ = ['data', 'spec', 'constant', 'up_to_date']
def __init__(self, x = UNCOMPUTED, constant = False):
self.constant = False
self.set_value(x) # allow set_value before constant = True
self.constant = constant
self.up_to_date = True
self.refresh() # to set spec
__slots__ = ['data', 'spec', 'constant', 'up_to_date']
def __str__(self): return str(self.data)
def __init__(self, x = None, constant = False):
self.constant = False
self.set_value(x) # allow set_value before constant = True
self.constant = constant
self.up_to_date = True
self.refresh() # to set spec
def __str__(self): return str(self.data)
def __repr__(self): return repr(self.data)
def __repr__(self): return repr(self.data)
#TODO: document this function, what does it do?
def refresh(self): self.spec = id(self.data)
#TODO: document this function, what does it do?
def refresh(self): self.spec = id(self.data)
####################################################
#
# Functionality provided by this class
#
####################################################
#
# Functionality provided by this class
#
def set_value(self, value):
if self.constant:
raise Exception("This Result is a constant. Its value cannot be changed.")
if value is None or value is UNCOMPUTED:
self.data = UNCOMPUTED
elif isinstance(value, ResultValue):
self.set_value(value.data)
else:
try:
self.data = self.set_value_filter(value)
except AbstractFunctionError, e:
self.data = value
def set_value(self, value):
if self.constant:
raise Exception("This Result is a constant. Its value cannot be changed.")
if value is None or value is UNCOMPUTED:
self.data = UNCOMPUTED
elif is_result(value):
self.set_value(value.data)
else:
try:
self.data = self.set_value_filter(value)
except AbstractFunctionError, e:
self.data = value
self.up_to_date = True
self.refresh()
self.up_to_date = True
self.refresh()
####################################################
#
# Pure virtual functions for subclasses to implement
#
####################################################
#
# Pure virtual functions for subclasses to implement
#
# Perform error checking or automatic conversion of value, and return the
# result (which will be stored as self.data)
# Called by: set_value()
def set_value_filter(self, value): raise AbstractFunctionError()
# Perform error checking or automatic conversion of value, and return the
# result (which will be stored as self.data)
# Called by: set_value()
def set_value_filter(self, value): raise AbstractFunctionError()
# For mutable data types, overwrite the current contents with value
# Also, call refresh and set up_to_date = True
def set_value_inplace(self, value): raise AbstractFunctionError()
# For mutable data types, overwrite the current contents with value
# Also, call refresh and set up_to_date = True
def set_value_inplace(self, value): raise AbstractFunctionError()
# Instantiate data (according to spec)
def alloc(self): raise AbstractFunctionError()
# Instantiate data (according to spec)
def alloc(self): raise AbstractFunctionError()
class DestroyHandler(features.Listener, features.Constraint, features.Orderings):
......@@ -467,7 +461,8 @@ class PythonOp(Op):
return all([ is_result(i) for i in self.inputs])
def gen_outputs(self):
return [ResultValue() for i in xrange(self.nout)]
raise NotImplementedError()
#return [ResultValue() for i in xrange(self.nout)]
def view_map(self): return {}
......@@ -500,7 +495,7 @@ class PythonOp(Op):
return answer
def check_input(self, input):
if input.data is UNCOMPUTED:
if input.data is None:
raise ValueError("Uncomputed input: %s in %s" % (input, self))
if not self.input_is_up_to_date(input):
raise ValueError("Input is out of date: %s in %s" % (input, self))
......
......@@ -10,7 +10,7 @@ from err import GofError
from utils import AbstractFunctionError
__all__ = ['Result', 'BrokenLink', 'BrokenLinkError']
__all__ = ['is_result', 'Result', 'BrokenLink', 'BrokenLinkError']
class BrokenLink:
......@@ -40,6 +40,11 @@ class BrokenLinkError(GofError):
# Result
############################
def is_result(obj):
"""Return True iff obj provides the interface of a Result"""
attr_list = 'owner',
return all([hasattr(obj, attr) for attr in attr_list])
class Result(object):
"""Storage node for data in a graph of Op instances.
......
......@@ -57,32 +57,30 @@ class Grad(object):
r may be uncomputed or NumpyR
"""
if dr is core.UNDEFINED:
if dr is core.UNDEFINED:
# nothing to do
pass
return
if r.data is not None and dr.data is not None:
if not hasattr(r, 'shape'):
raise ValueError(('Grad::add r lacks shape: type=',
type(r)))
if not hasattr(dr, 'shape'):
raise ValueError(('Grad::add dr lacks shape: type=',
type(dr)))
if r.shape != dr.shape:
raise ValueError(('Grad::add r, dr shape mismatch',
r.shape, dr.shape))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if r.computed:
self._compute_history.add(r)
# add dr to self[r]
if r in self:
self[r] = self[r] + dr
else:
if r.data is core.UNCOMPUTED or dr.data is core.UNCOMPUTED:
pass # no sanity checking
else: # some sanity checking to catch obvious mistakes
if not hasattr(r.data, 'shape'):
raise ValueError(('Grad::add r lacks shape: type=',
type(r.data)))
if not hasattr(dr.data, 'shape'):
raise ValueError(('Grad::add dr lacks shape: type=',
type(dr.data)))
if r.data.shape != dr.data.shape:
raise ValueError(('Grad::add r, dr shape mismatch',
r.data.shape, dr.data.shape))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if r.data is not core.UNCOMPUTED:
self._compute_history.add(r)
# add dr to self[r]
if r in self:
self[r] = self[r] + dr
else:
self[r] = dr
self[r] = dr
def bprop(self, maybe_redo=False):
"""Build a backpropagation graph.
......@@ -213,14 +211,14 @@ class _testCase (unittest.TestCase):
diff = wwi - ident
ssdiff = core.sum((diff**2))
if i == 0:
str0 = str_ssdiff = str(ssdiff)
str0 = str_ssdiff = str(ssdiff.data)
#print ssdiff
g = grad(ssdiff)
gw = g(w)
w.data += -0.4 * gw.data
w.data[:] += -0.4 * gw.data
return str0, str(ssdiff)
return str0, str(ssdiff.data)
def matinv_compiled(self, dim):
w = core.wrap(numpy.random.rand(dim,dim))
......@@ -230,7 +228,7 @@ class _testCase (unittest.TestCase):
wwi = core.dot(w, wi)
diff = wwi - ident
ssdiff = core.sum((diff**2))
str0 = str_ssdiff = str(ssdiff)
str0 = str_ssdiff = str(ssdiff.data)
#print ssdiff
g = grad(ssdiff)
......@@ -240,9 +238,9 @@ class _testCase (unittest.TestCase):
for i in xrange(300):
prog()
w.data += -0.4 * gw.data
w.data[:] += -0.4 * gw.data
return str0, str(ssdiff)
return str0, str(ssdiff.data)
def test0(self):
"""Matrix inversion by gradient descent (eval mode)"""
......
......@@ -9,7 +9,7 @@ import grad
# Wrapper type
class SparseR(gof.ResultValue):
class SparseR(core.ResultBase):
"""
Attribute:
format - a subclass of sparse.spmatrix indicating self.data.__class__
......@@ -22,12 +22,17 @@ class SparseR(gof.ResultValue):
Notes:
"""
def __init__(self, x = core.UNCOMPUTED, constant = False,
def __init__(self, data=None, role=None, constant = False,
format = sparse.csr_matrix):
gof.ResultValue.__init__(self, x, constant)
self.format = isinstance(x, sparse.spmatrix) and x.__class__ or format
core.ResultBase.__init__(self, role, data, constant)
if isinstance(data, sparse.spmatrix):
self.format = data.__class__
else:
self.format = format
self._dtype = None
self._shape = None
def set_value_filter(self, value):
def data_filter(self, value):
if isinstance(value, sparse.spmatrix): return value
return sparse.csr_matrix(value)
......@@ -36,6 +41,32 @@ class SparseR(gof.ResultValue):
T = property(lambda self: transpose(self), doc = "Return aliased transpose")
# self._dtype is used when self._data hasn't been set yet
def __dtype_get(self):
if self._data is None:
return self._dtype
else:
return self._data.dtype
def __dtype_set(self, dtype):
if self._data is None:
self._dtype = dtype
else:
raise StateError('cannot set dtype after data has been set')
dtype = property(__dtype_get, __dtype_set)
# self._shape is used when self._data hasn't been set yet
def __shape_get(self):
if self._data is None:
return self._shape
else:
return self._data.shape
def __shape_set(self, shape):
if self._data is None:
self._shape = shape
else:
raise StateError('cannot set shape after data has been set')
shape = property(__shape_get, __shape_set)
# convenience base class
class op(gof.PythonOp, grad.update_gradient_via_grad):
pass
......@@ -46,7 +77,7 @@ class op(gof.PythonOp, grad.update_gradient_via_grad):
# convert a sparse matrix to an ndarray
class sparse2dense(op):
def gen_outputs(self): return [core.NumpyR()]
def gen_outputs(self): return [core.Numpy2()]
def impl(x): return numpy.asarray(x.todense())
def grad(self, x, gz):
if x.format is sparse.coo_matrix: return dense2coo(gz)
......@@ -135,7 +166,7 @@ class _testCase_dot(unittest.TestCase):
def test_basic0(self):
for mtype in [sparse.csc_matrix, sparse.csr_matrix]:
x = SparseR(mtype(sparse.speye(5,3)))
y = core.NumpyR(numpy.random.rand(3, 2))
y = core.wrap(numpy.random.rand(3, 2))
z = dot(x,y)
self.failUnless(z.data.shape == (5,2))
......@@ -171,7 +202,7 @@ class _testCase_dot(unittest.TestCase):
self.failUnless(z.data.shape == ab.shape)
self.failUnless(type(z.data) == type(ab))
def test_graph_bprop0(self):
x = core.NumpyR(numpy.random.rand(10,2))
x = core.wrap(numpy.random.rand(10,2))
w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0,
0]],dtype='float64')))
......@@ -187,10 +218,10 @@ class _testCase_dot(unittest.TestCase):
g(w).data[1,4] = 0
w.data = -lr * g(w).data + w.data
self.failUnless('3.08560636025' == str(loss))
self.failUnless('3.08560636025' == str(loss.data))
def test_graph_bprop1(self):
x = core.NumpyR(numpy.random.rand(10,2))
x = core.wrap(numpy.random.rand(10,2))
w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0,
0]],dtype='float64')))
......@@ -205,7 +236,7 @@ class _testCase_dot(unittest.TestCase):
g(w).data[1,4] = 0
w.data = -lr * g(w).data + w.data
self.failUnless('3.08560636025' == str(loss))
self.failUnless('3.08560636025' == str(loss.data))
if __name__ == '__main__':
unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论