m

上级 efc14dba
......@@ -11,7 +11,8 @@ import numpy
from scipy import weave
import gof
from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode, pop_mode, UNCOMPUTED, UNDEFINED, PythonR
from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode
from gof import pop_mode, UNCOMPUTED, UNDEFINED, ResultValue
import type_spec
import cutils
......@@ -105,12 +106,12 @@ def input(x):
elif isinstance(x, gof.Result):
raise TypeError("%s is already a result." % x)
else:
return PythonR(x)
return ResultValue(x)
def wrap(x):
if isinstance(x, NumpyR):
return x
elif isinstance(x, PythonR):
elif isinstance(x, ResultValue):
return x
elif isinstance(x, omega_op):
return x.out
......@@ -144,7 +145,7 @@ def _literal_unhashable(x):
return r
def literal(x):
"""Return a PythonR instance wrapping a literal."""
"""Return a ResultValue instance wrapping a literal."""
if _hashable(x):
return _literal_hashable(x)
else:
......@@ -253,18 +254,18 @@ class omega_op(gof.PythonOp):
self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
In general, grad() should return a list of PythonR instances whose
In general, grad() should return a list of ResultValue instances whose
length matches that of self.inputs, and whose elements are the
gradients of self.inputs.
There is a (but often used) special feature in place to automatically
wrap the return value of grad() in a list if it is a PythonR instance
wrap the return value of grad() in a list if it is a ResultValue instance
and the op is unary. This makes many grad implementations a little
cuter.
"""
inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
if len(self.inputs) == 1 and isinstance(inputgs, gof.PythonR):
if len(self.inputs) == 1 and isinstance(inputgs, gof.ResultValue):
inputgs = [inputgs]
else:
assert len(inputgs) == len(self.inputs)
......@@ -660,9 +661,10 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
return normal_f(x, y)
return f
class NumpyR(gof.PythonR):
class NumpyR(gof.ResultValue):
"""The class for storing ndarray return values from omega ops.
The class provides additional functionality compared to the normal PythonR:
The class provides additional functionality compared to the normal
ResultValue:
- operator overloads that correspond to omega ops such as add() and scale()
- special attributes that make it behave like an ndarray when passed to
numpy functions.
......@@ -681,13 +683,7 @@ class NumpyR(gof.PythonR):
__array__ = property(lambda self: self.data.__array__ )
__array_struct__ = property(lambda self: self.data.__array_struct__ )
def set_value(self, value):
if value is UNCOMPUTED:
self.data = UNCOMPUTED
else:
self.data = numpy.asarray(value)
self.refresh()
self.up_to_date = True
def set_value_filter(self, value): return numpy.asarray(value)
def set_value_inplace(self, value):
if value is UNCOMPUTED:
......
from op import Op
from result import Result #, HolderResult
from utils import ClsInit, Keyword
from utils import ClsInit, Keyword, AbstractFunctionError
import opt
import env
import features
......@@ -16,7 +16,7 @@ __all__ = ['UNCOMPUTED',
'eval_mode',
'build_eval_mode',
'pop_mode',
'PythonR',
'ResultValue',
'DummyOp',
'DummyRemover',
'PythonOp',
......@@ -27,7 +27,6 @@ __all__ = ['UNCOMPUTED',
UNCOMPUTED = Keyword("UNCOMPUTED", False)
UNDEFINED = Keyword("UNDEFINED", False)
def make_static(cls, fname):
f = getattr(cls, fname)
if hasattr(f, 'im_func'):
......@@ -79,7 +78,26 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
class PythonR(Result):
class ResultValue(Result):
"""Augment Result to wrap a computed value.
Attributes:
data -
spec -
constant -
up_to_date -
Properties:
Methods:
set_value_filter - ABSTRACT
set_value_inplace - ABSTRACT
alloc - ABSTRACT
Notes:
"""
__slots__ = ['data', 'spec', 'constant', 'up_to_date']
......@@ -90,37 +108,57 @@ class PythonR(Result):
self.up_to_date = True
self.spec = None
def __str__(self): return str(self.data)
def __repr__(self): return repr(self.data)
#TODO: document this function, what does it do?
def refresh(self): self.spec = id(self.data)
####################################################
#
# Functionality provided by this class
#
def set_value(self, value):
if self.constant:
raise Exception("This Result is a constant. Its value cannot be changed.")
if value is None or value is UNCOMPUTED:
self.data = UNCOMPUTED
elif isinstance(value, PythonR):
elif isinstance(value, ResultValue):
self.set_value(value.data)
else:
self.data = value
try:
self.data = self.set_value_filter(value)
except AbstractFunctionError, e:
self.data = value
self.up_to_date = True
self.refresh()
def set_value_inplace(self, value):
raise NotImplementedError()
def compute(self):
#HACK: this is potentially very broken behaviour
"""Overrides Op.compute(). Only recurses if self.data is UNCOMPUTED"""
if self.data is UNCOMPUTED:
Result.compute(self)
def __str__(self):
return str(self.data)
####################################################
#
# Pure virtual functions for subclasses to implement
#
def __repr__(self):
return repr(self.data)
# Perform error checking or automatic conversion of value, and return the
# result (which will be stored as self.data)
# Called by: set_value()
def set_value_filter(self, value): raise AbstractFunctionError()
def refresh(self):
self.spec = id(self.data)
# For mutable data types, overwrite the current contents with value
# Also, call refresh and set up_to_date = True
def set_value_inplace(self, value): raise AbstractFunctionError()
def alloc(self):
raise TypeError("Cannot allocate following this specification.")
# Instantiate data (according to spec)
def alloc(self): raise AbstractFunctionError()
def compute(self):
"""Overrides Op.compute(). Only recurses if self.data is UNCOMPUTED"""
if self.data is UNCOMPUTED:
self.owner.compute()
class PythonOp(Op):
......@@ -157,10 +195,10 @@ class PythonOp(Op):
def __validate__(self):
for input in self.inputs:
assert isinstance(input, PythonR)
assert isinstance(input, ResultValue)
def gen_outputs(self):
return [PythonR() for i in xrange(self.nout)]
return [ResultValue() for i in xrange(self.nout)]
def root_inputs(self, input):
owner = input.owner
......
......@@ -7,6 +7,7 @@ value that is the input or the output of an Op.
from err import GofError
from utils import AbstractFunctionError
__all__ = ['Result', 'BrokenLink', 'BrokenLinkError']
......@@ -40,13 +41,17 @@ class BrokenLinkError(GofError):
############################
class Result(object):
"""
The Result class represents a datum for use in a graph of Ops. It
has two slots:
"""Storage node for data in a graph of Op instances.
Attributes:
owner - represents the Op which computes this Result. Contains either None
or an instance of Op.
index - the index of this Result in owner.outputs.
Methods:
-
- owner: represents the Op which computes this Result. Contains either None
or an instance of Op.
- index: the index of this Result in owner.outputs.
Notes:
Result has no __init__ or __new__ routine. It is the Op's
responsibility to set the owner field of its results.
......@@ -70,7 +75,8 @@ class Result(object):
self._owner = None
return self._owner
owner = property(get_owner, doc = "The Op of which this Result is an output or None if there is no such Op.")
owner = property(get_owner,
doc = "The Op of which this Result is an output or None if there is no such Op.")
def set_owner(self, owner, index):
if self.owner is not None:
......@@ -94,21 +100,21 @@ class Result(object):
self._owner = owner
self._index = index
def compute(self):
"""If self has an owner, recursively compute it.
def set_value(self, value):
"""
Copies the provided value in this Result. It is not required to
implement this method.
"""
raise NotImplementedError("This Result does not support set_value.")
This is a mutually recursive function with gof.op.Op
def compute(self):
"""If self has an owner, recursively compute it."""
"""
if self.owner:
self.owner.compute()
def perform(self):
"""Calls self.owner.perform() if self.owner exists."""
"""Calls self.owner.perform() if self.owner exists.
This is a mutually recursive function with gof.op.Op
"""
if self.owner:
self.owner.perform()
......
......@@ -3,9 +3,14 @@
# import result
class OmegaError(Exception):
pass
class OmegaError(Exception): pass
class AbstractFunctionError(Exception):
"""To be raised by functions defined as part of an interface.
When the user sees such an error, it is because an important interface
function has been left out of an implementation class.
"""
def all_bases(cls, accept):
......
......@@ -156,18 +156,18 @@ class update_gradient_via_grad:
self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
In general, grad() should return a list of PythonR instances whose
In general, grad() should return a list of ResultValue instances whose
length matches that of self.inputs, and whose elements are the
gradients of self.inputs.
There is a (but often used) special feature in place to automatically
wrap the return value of grad() in a list if it is a PythonR instance
wrap the return value of grad() in a list if it is a ResultValue instance
and the op is unary. This makes many grad implementations a little
cuter.
"""
inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
if len(self.inputs) == 1 and isinstance(inputgs, gof.PythonR):
if len(self.inputs) == 1 and isinstance(inputgs, gof.ResultValue):
inputgs = [inputgs]
else:
assert len(inputgs) == len(self.inputs)
......
......@@ -16,7 +16,7 @@ class RandomState(gof.Op, gof.ext.IONames):
def __init__(self, seed):
inputs = [wrap(seed)]
outputs = [PythonR()]
outputs = [ResultValue()]
gof.Op.__init__(self, inputs, outputs)
def thunk(self):
......
......@@ -9,24 +9,27 @@ import grad
# Wrapper type
class SparseR(gof.PythonR):
class SparseR(gof.ResultValue):
"""
Attribute:
format - a subclass of sparse.spmatrix indicating self.data.__class__
Properties:
T - read-only: return a transpose of self
Methods:
Notes:
"""
def __init__(self, x = core.UNCOMPUTED, constant = False,
format = sparse.csr_matrix):
gof.PythonR.__init__(self, x, constant)
gof.ResultValue.__init__(self, x, constant)
self.format = isinstance(x, sparse.spmatrix) and x.__class__ or format
def set_value(self, value):
"""Extend base impl, assert value is sparse matrix"""
gof.PythonR.set_value(self,value)
if self.data is not core.UNCOMPUTED:
if not isinstance(self.data, sparse.spmatrix):
print self.data.__class__
print self.owner.__class__
raise TypeError(('hrm',value))
def set_value_filter(self, value):
if isinstance(value, sparse.spmatrix): return value
return sparse.csr_matrix(value)
def __add__(left, right): return add(left, right)
def __radd__(right, left): return add(left, right)
......@@ -148,11 +151,11 @@ class _testCase_dot(unittest.TestCase):
m = mtype(a)
ab = m.dot(b)
try:
z = dot(SparseR(m),gof.lib.PythonR(b))
z = dot(SparseR(m),gof.lib.ResultValue(b))
self.failUnless(z.data.shape == ab.shape)
self.failUnless(type(z.data) == type(ab))
except Exception, e:
print mtype, e, str(e)
print 'cccc', mtype, e, str(e)
raise
def test_basic2(self):
"""dot: sparse right"""
......@@ -164,7 +167,7 @@ class _testCase_dot(unittest.TestCase):
sparse.lil_matrix]:#, sparse.coo_matrix]:
m = mtype(b)
ab = m.transpose().dot(a.transpose()).transpose()
z = dot(gof.lib.PythonR(a),SparseR(mtype(b)))
z = dot(gof.lib.ResultValue(a),SparseR(mtype(b)))
self.failUnless(z.data.shape == ab.shape)
self.failUnless(type(z.data) == type(ab))
def test_graph_bprop0(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论