m

上级 efc14dba
...@@ -11,7 +11,8 @@ import numpy ...@@ -11,7 +11,8 @@ import numpy
from scipy import weave from scipy import weave
import gof import gof
from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode, pop_mode, UNCOMPUTED, UNDEFINED, PythonR from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode
from gof import pop_mode, UNCOMPUTED, UNDEFINED, ResultValue
import type_spec import type_spec
import cutils import cutils
...@@ -105,12 +106,12 @@ def input(x): ...@@ -105,12 +106,12 @@ def input(x):
elif isinstance(x, gof.Result): elif isinstance(x, gof.Result):
raise TypeError("%s is already a result." % x) raise TypeError("%s is already a result." % x)
else: else:
return PythonR(x) return ResultValue(x)
def wrap(x): def wrap(x):
if isinstance(x, NumpyR): if isinstance(x, NumpyR):
return x return x
elif isinstance(x, PythonR): elif isinstance(x, ResultValue):
return x return x
elif isinstance(x, omega_op): elif isinstance(x, omega_op):
return x.out return x.out
...@@ -144,7 +145,7 @@ def _literal_unhashable(x): ...@@ -144,7 +145,7 @@ def _literal_unhashable(x):
return r return r
def literal(x): def literal(x):
"""Return a PythonR instance wrapping a literal.""" """Return a ResultValue instance wrapping a literal."""
if _hashable(x): if _hashable(x):
return _literal_hashable(x) return _literal_hashable(x)
else: else:
...@@ -253,18 +254,18 @@ class omega_op(gof.PythonOp): ...@@ -253,18 +254,18 @@ class omega_op(gof.PythonOp):
self.grad(*(self.inputs + [grad_d[output] for output in self.outputs])) self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
In general, grad() should return a list of PythonR instances whose In general, grad() should return a list of ResultValue instances whose
length matches that of self.inputs, and whose elements are the length matches that of self.inputs, and whose elements are the
gradients of self.inputs. gradients of self.inputs.
There is a (but often used) special feature in place to automatically There is a (but often used) special feature in place to automatically
wrap the return value of grad() in a list if it is a PythonR instance wrap the return value of grad() in a list if it is a ResultValue instance
and the op is unary. This makes many grad implementations a little and the op is unary. This makes many grad implementations a little
cuter. cuter.
""" """
inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs])) inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
if len(self.inputs) == 1 and isinstance(inputgs, gof.PythonR): if len(self.inputs) == 1 and isinstance(inputgs, gof.ResultValue):
inputgs = [inputgs] inputgs = [inputgs]
else: else:
assert len(inputgs) == len(self.inputs) assert len(inputgs) == len(self.inputs)
...@@ -660,9 +661,10 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None): ...@@ -660,9 +661,10 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
return normal_f(x, y) return normal_f(x, y)
return f return f
class NumpyR(gof.PythonR): class NumpyR(gof.ResultValue):
"""The class for storing ndarray return values from omega ops. """The class for storing ndarray return values from omega ops.
The class provides additional functionality compared to the normal PythonR: The class provides additional functionality compared to the normal
ResultValue:
- operator overloads that correspond to omega ops such as add() and scale() - operator overloads that correspond to omega ops such as add() and scale()
- special attributes that make it behave like an ndarray when passed to - special attributes that make it behave like an ndarray when passed to
numpy functions. numpy functions.
...@@ -681,13 +683,7 @@ class NumpyR(gof.PythonR): ...@@ -681,13 +683,7 @@ class NumpyR(gof.PythonR):
__array__ = property(lambda self: self.data.__array__ ) __array__ = property(lambda self: self.data.__array__ )
__array_struct__ = property(lambda self: self.data.__array_struct__ ) __array_struct__ = property(lambda self: self.data.__array_struct__ )
def set_value(self, value): def set_value_filter(self, value): return numpy.asarray(value)
if value is UNCOMPUTED:
self.data = UNCOMPUTED
else:
self.data = numpy.asarray(value)
self.refresh()
self.up_to_date = True
def set_value_inplace(self, value): def set_value_inplace(self, value):
if value is UNCOMPUTED: if value is UNCOMPUTED:
......
from op import Op from op import Op
from result import Result #, HolderResult from result import Result #, HolderResult
from utils import ClsInit, Keyword from utils import ClsInit, Keyword, AbstractFunctionError
import opt import opt
import env import env
import features import features
...@@ -16,7 +16,7 @@ __all__ = ['UNCOMPUTED', ...@@ -16,7 +16,7 @@ __all__ = ['UNCOMPUTED',
'eval_mode', 'eval_mode',
'build_eval_mode', 'build_eval_mode',
'pop_mode', 'pop_mode',
'PythonR', 'ResultValue',
'DummyOp', 'DummyOp',
'DummyRemover', 'DummyRemover',
'PythonOp', 'PythonOp',
...@@ -27,7 +27,6 @@ __all__ = ['UNCOMPUTED', ...@@ -27,7 +27,6 @@ __all__ = ['UNCOMPUTED',
UNCOMPUTED = Keyword("UNCOMPUTED", False) UNCOMPUTED = Keyword("UNCOMPUTED", False)
UNDEFINED = Keyword("UNDEFINED", False) UNDEFINED = Keyword("UNDEFINED", False)
def make_static(cls, fname): def make_static(cls, fname):
f = getattr(cls, fname) f = getattr(cls, fname)
if hasattr(f, 'im_func'): if hasattr(f, 'im_func'):
...@@ -79,7 +78,26 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint): ...@@ -79,7 +78,26 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
class PythonR(Result): class ResultValue(Result):
"""Augment Result to wrap a computed value.
Attributes:
data -
spec -
constant -
up_to_date -
Properties:
Methods:
set_value_filter - ABSTRACT
set_value_inplace - ABSTRACT
alloc - ABSTRACT
Notes:
"""
__slots__ = ['data', 'spec', 'constant', 'up_to_date'] __slots__ = ['data', 'spec', 'constant', 'up_to_date']
...@@ -90,37 +108,57 @@ class PythonR(Result): ...@@ -90,37 +108,57 @@ class PythonR(Result):
self.up_to_date = True self.up_to_date = True
self.spec = None self.spec = None
def __str__(self): return str(self.data)
def __repr__(self): return repr(self.data)
#TODO: document this function, what does it do?
def refresh(self): self.spec = id(self.data)
####################################################
#
# Functionality provided by this class
#
def set_value(self, value): def set_value(self, value):
if self.constant: if self.constant:
raise Exception("This Result is a constant. Its value cannot be changed.") raise Exception("This Result is a constant. Its value cannot be changed.")
if value is None or value is UNCOMPUTED: if value is None or value is UNCOMPUTED:
self.data = UNCOMPUTED self.data = UNCOMPUTED
elif isinstance(value, PythonR): elif isinstance(value, ResultValue):
self.set_value(value.data) self.set_value(value.data)
else: else:
try:
self.data = self.set_value_filter(value)
except AbstractFunctionError, e:
self.data = value self.data = value
self.up_to_date = True self.up_to_date = True
self.refresh() self.refresh()
def set_value_inplace(self, value): def compute(self):
raise NotImplementedError() #HACK: this is potentially very broken behaviour
"""Overrides Op.compute(). Only recurses if self.data is UNCOMPUTED"""
if self.data is UNCOMPUTED:
Result.compute(self)
def __str__(self): ####################################################
return str(self.data) #
# Pure virtual functions for subclasses to implement
#
def __repr__(self): # Perform error checking or automatic conversion of value, and return the
return repr(self.data) # result (which will be stored as self.data)
# Called by: set_value()
def set_value_filter(self, value): raise AbstractFunctionError()
def refresh(self): # For mutable data types, overwrite the current contents with value
self.spec = id(self.data) # Also, call refresh and set up_to_date = True
def set_value_inplace(self, value): raise AbstractFunctionError()
def alloc(self): # Instantiate data (according to spec)
raise TypeError("Cannot allocate following this specification.") def alloc(self): raise AbstractFunctionError()
def compute(self):
"""Overrides Op.compute(). Only recurses if self.data is UNCOMPUTED"""
if self.data is UNCOMPUTED:
self.owner.compute()
class PythonOp(Op): class PythonOp(Op):
...@@ -157,10 +195,10 @@ class PythonOp(Op): ...@@ -157,10 +195,10 @@ class PythonOp(Op):
def __validate__(self): def __validate__(self):
for input in self.inputs: for input in self.inputs:
assert isinstance(input, PythonR) assert isinstance(input, ResultValue)
def gen_outputs(self): def gen_outputs(self):
return [PythonR() for i in xrange(self.nout)] return [ResultValue() for i in xrange(self.nout)]
def root_inputs(self, input): def root_inputs(self, input):
owner = input.owner owner = input.owner
......
...@@ -7,6 +7,7 @@ value that is the input or the output of an Op. ...@@ -7,6 +7,7 @@ value that is the input or the output of an Op.
from err import GofError from err import GofError
from utils import AbstractFunctionError
__all__ = ['Result', 'BrokenLink', 'BrokenLinkError'] __all__ = ['Result', 'BrokenLink', 'BrokenLinkError']
...@@ -40,13 +41,17 @@ class BrokenLinkError(GofError): ...@@ -40,13 +41,17 @@ class BrokenLinkError(GofError):
############################ ############################
class Result(object): class Result(object):
""" """Storage node for data in a graph of Op instances.
The Result class represents a datum for use in a graph of Ops. It
has two slots:
- owner: represents the Op which computes this Result. Contains either None Attributes:
owner - represents the Op which computes this Result. Contains either None
or an instance of Op. or an instance of Op.
- index: the index of this Result in owner.outputs. index - the index of this Result in owner.outputs.
Methods:
-
Notes:
Result has no __init__ or __new__ routine. It is the Op's Result has no __init__ or __new__ routine. It is the Op's
responsibility to set the owner field of its results. responsibility to set the owner field of its results.
...@@ -70,7 +75,8 @@ class Result(object): ...@@ -70,7 +75,8 @@ class Result(object):
self._owner = None self._owner = None
return self._owner return self._owner
owner = property(get_owner, doc = "The Op of which this Result is an output or None if there is no such Op.") owner = property(get_owner,
doc = "The Op of which this Result is an output or None if there is no such Op.")
def set_owner(self, owner, index): def set_owner(self, owner, index):
if self.owner is not None: if self.owner is not None:
...@@ -94,21 +100,21 @@ class Result(object): ...@@ -94,21 +100,21 @@ class Result(object):
self._owner = owner self._owner = owner
self._index = index self._index = index
def compute(self):
"""If self has an owner, recursively compute it.
def set_value(self, value): This is a mutually recursive function with gof.op.Op
"""
Copies the provided value in this Result. It is not required to
implement this method.
"""
raise NotImplementedError("This Result does not support set_value.")
def compute(self): """
"""If self has an owner, recursively compute it."""
if self.owner: if self.owner:
self.owner.compute() self.owner.compute()
def perform(self): def perform(self):
"""Calls self.owner.perform() if self.owner exists.""" """Calls self.owner.perform() if self.owner exists.
This is a mutually recursive function with gof.op.Op
"""
if self.owner: if self.owner:
self.owner.perform() self.owner.perform()
......
...@@ -3,9 +3,14 @@ ...@@ -3,9 +3,14 @@
# import result # import result
class OmegaError(Exception): class OmegaError(Exception): pass
pass
class AbstractFunctionError(Exception):
"""To be raised by functions defined as part of an interface.
When the user sees such an error, it is because an important interface
function has been left out of an implementation class.
"""
def all_bases(cls, accept): def all_bases(cls, accept):
......
...@@ -156,18 +156,18 @@ class update_gradient_via_grad: ...@@ -156,18 +156,18 @@ class update_gradient_via_grad:
self.grad(*(self.inputs + [grad_d[output] for output in self.outputs])) self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
In general, grad() should return a list of PythonR instances whose In general, grad() should return a list of ResultValue instances whose
length matches that of self.inputs, and whose elements are the length matches that of self.inputs, and whose elements are the
gradients of self.inputs. gradients of self.inputs.
There is a (but often used) special feature in place to automatically There is a (but often used) special feature in place to automatically
wrap the return value of grad() in a list if it is a PythonR instance wrap the return value of grad() in a list if it is a ResultValue instance
and the op is unary. This makes many grad implementations a little and the op is unary. This makes many grad implementations a little
cuter. cuter.
""" """
inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs])) inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
if len(self.inputs) == 1 and isinstance(inputgs, gof.PythonR): if len(self.inputs) == 1 and isinstance(inputgs, gof.ResultValue):
inputgs = [inputgs] inputgs = [inputgs]
else: else:
assert len(inputgs) == len(self.inputs) assert len(inputgs) == len(self.inputs)
......
...@@ -16,7 +16,7 @@ class RandomState(gof.Op, gof.ext.IONames): ...@@ -16,7 +16,7 @@ class RandomState(gof.Op, gof.ext.IONames):
def __init__(self, seed): def __init__(self, seed):
inputs = [wrap(seed)] inputs = [wrap(seed)]
outputs = [PythonR()] outputs = [ResultValue()]
gof.Op.__init__(self, inputs, outputs) gof.Op.__init__(self, inputs, outputs)
def thunk(self): def thunk(self):
......
...@@ -9,24 +9,27 @@ import grad ...@@ -9,24 +9,27 @@ import grad
# Wrapper type # Wrapper type
class SparseR(gof.PythonR): class SparseR(gof.ResultValue):
""" """
Attribute: Attribute:
format - a subclass of sparse.spmatrix indicating self.data.__class__ format - a subclass of sparse.spmatrix indicating self.data.__class__
Properties:
T - read-only: return a transpose of self
Methods:
Notes:
""" """
def __init__(self, x = core.UNCOMPUTED, constant = False, def __init__(self, x = core.UNCOMPUTED, constant = False,
format = sparse.csr_matrix): format = sparse.csr_matrix):
gof.PythonR.__init__(self, x, constant) gof.ResultValue.__init__(self, x, constant)
self.format = isinstance(x, sparse.spmatrix) and x.__class__ or format self.format = isinstance(x, sparse.spmatrix) and x.__class__ or format
def set_value(self, value): def set_value_filter(self, value):
"""Extend base impl, assert value is sparse matrix""" if isinstance(value, sparse.spmatrix): return value
gof.PythonR.set_value(self,value) return sparse.csr_matrix(value)
if self.data is not core.UNCOMPUTED:
if not isinstance(self.data, sparse.spmatrix):
print self.data.__class__
print self.owner.__class__
raise TypeError(('hrm',value))
def __add__(left, right): return add(left, right) def __add__(left, right): return add(left, right)
def __radd__(right, left): return add(left, right) def __radd__(right, left): return add(left, right)
...@@ -148,11 +151,11 @@ class _testCase_dot(unittest.TestCase): ...@@ -148,11 +151,11 @@ class _testCase_dot(unittest.TestCase):
m = mtype(a) m = mtype(a)
ab = m.dot(b) ab = m.dot(b)
try: try:
z = dot(SparseR(m),gof.lib.PythonR(b)) z = dot(SparseR(m),gof.lib.ResultValue(b))
self.failUnless(z.data.shape == ab.shape) self.failUnless(z.data.shape == ab.shape)
self.failUnless(type(z.data) == type(ab)) self.failUnless(type(z.data) == type(ab))
except Exception, e: except Exception, e:
print mtype, e, str(e) print 'cccc', mtype, e, str(e)
raise raise
def test_basic2(self): def test_basic2(self):
"""dot: sparse right""" """dot: sparse right"""
...@@ -164,7 +167,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -164,7 +167,7 @@ class _testCase_dot(unittest.TestCase):
sparse.lil_matrix]:#, sparse.coo_matrix]: sparse.lil_matrix]:#, sparse.coo_matrix]:
m = mtype(b) m = mtype(b)
ab = m.transpose().dot(a.transpose()).transpose() ab = m.transpose().dot(a.transpose()).transpose()
z = dot(gof.lib.PythonR(a),SparseR(mtype(b))) z = dot(gof.lib.ResultValue(a),SparseR(mtype(b)))
self.failUnless(z.data.shape == ab.shape) self.failUnless(z.data.shape == ab.shape)
self.failUnless(type(z.data) == type(ab)) self.failUnless(type(z.data) == type(ab))
def test_graph_bprop0(self): def test_graph_bprop0(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论