broken

上级 460f6b78
import time import time, unittest
import numpy
import gof import gof
import gof.lib import gof.lib
...@@ -170,13 +171,13 @@ class prog(gof.Prog): ...@@ -170,13 +171,13 @@ class prog(gof.Prog):
""" """
if check_uncomputed: if check_uncomputed:
for input in self.env.inputs: for input in self.env.inputs:
if input.data is core.UNCOMPUTED: if input.data is None:
raise Exception("You must provide a value for input %s!" % input) raise Exception("You must provide a value for input %s!" % input)
return gof.Prog.__call__(self) return gof.Prog.__call__(self)
def compute_orphans(self): def compute_orphans(self):
for orphan in self.env.orphans(): for orphan in self.env.orphans():
if orphan.data is core.UNCOMPUTED: if orphan.data is None:
if orphan.owner: if orphan.owner:
gof.lib.compute(orphan.owner) gof.lib.compute(orphan.owner)
else: else:
...@@ -200,3 +201,22 @@ def to_func(inputs, outputs): ...@@ -200,3 +201,22 @@ def to_func(inputs, outputs):
def single(*outputs, **kwargs): def single(*outputs, **kwargs):
return prog(gof.graph.inputs(outputs), outputs, **kwargs) return prog(gof.graph.inputs(outputs), outputs, **kwargs)
class _test_single(unittest.TestCase):
def setUp(self):
core.build_eval_mode()
numpy.random.seed(44)
def tearDown(self):
core.pop_mode()
def test_3(self):
a = core.Numpy2(data=numpy.ones((2,2)))
b = core.Numpy2(data=numpy.ones((2,2)))
c = core.add(a,b)
p = single(c)
p()
self.failUnless(core._approx_eq(c, numpy.ones((2,2))*2))
if __name__ == '__main__':
unittest.main()
...@@ -12,7 +12,7 @@ from scipy import weave ...@@ -12,7 +12,7 @@ from scipy import weave
import gof import gof
from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode
from gof import pop_mode, UNCOMPUTED, UNDEFINED, ResultValue from gof import pop_mode, UNDEFINED, is_result
import type_spec import type_spec
import cutils import cutils
...@@ -45,7 +45,7 @@ def _approx_eq(a,b,eps=1.0e-9): ...@@ -45,7 +45,7 @@ def _approx_eq(a,b,eps=1.0e-9):
b = numpy.asarray(b) b = numpy.asarray(b)
if a.shape != b.shape: if a.shape != b.shape:
return False return False
return numpy.max( numpy.abs(a-b)) < eps return numpy.max(numpy.abs(a-b)) < eps
# This function is only executed the first time it is called, subsequent calls # This function is only executed the first time it is called, subsequent calls
...@@ -84,7 +84,7 @@ def _compile_dir(): ...@@ -84,7 +84,7 @@ def _compile_dir():
sys.path.append(cachedir) sys.path.append(cachedir)
return cachedir return cachedir
class ResultBase: class ResultBase(object):
"""Base class for storing Op inputs and outputs """Base class for storing Op inputs and outputs
Attributes: Attributes:
...@@ -98,6 +98,7 @@ class ResultBase: ...@@ -98,6 +98,7 @@ class ResultBase:
index - (ro) index - (ro)
data - (rw) data - (rw)
replaced - (rw) : True iff _role is BrokenLink replaced - (rw) : True iff _role is BrokenLink
computed - (ro) : True iff contents of data are fresh
Abstract Methods: Abstract Methods:
data_filter data_filter
...@@ -170,11 +171,12 @@ class ResultBase: ...@@ -170,11 +171,12 @@ class ResultBase:
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise ResultBase.BrokenLinkError()
if self.constant: raise Exception('cannot set constant ResultBase') if self.constant: raise Exception('cannot set constant ResultBase')
try: try:
self._data = self.data_filter(self, data) self._data = self.data_filter(data)
except ResultBase.AbstractFunction: except ResultBase.AbstractFunction: #use default behaviour
self._data = data self._data = data
data = property(__get_data, __set_data, data = property(__get_data, __set_data,
doc = "The storage associated with this result") doc = "The storage associated with this result")
def data_filter(self, data): def data_filter(self, data):
"""(abstract) Return an appropriate _data based on data.""" """(abstract) Return an appropriate _data based on data."""
raise ResultBase.AbstractFunction() raise ResultBase.AbstractFunction()
...@@ -189,7 +191,28 @@ class ResultBase: ...@@ -189,7 +191,28 @@ class ResultBase:
else: else:
self._role = self._role.old_role self._role = self._role.old_role
replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?") replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
# computed
#TODO: think about how to handle this more correctly
computed = property(lambda self: self._data is not None)
#################
# NumpyR Compatibility
#
up_to_date = property(lambda self: True)
def refresh(self): pass
def set_owner(self, owner, idx):
self.role = (owner, idx)
def set_value(self, value):
self.data = value #may raise exception
class _test_ResultBase(unittest.TestCase): class _test_ResultBase(unittest.TestCase):
def setUp(self):
build_eval_mode()
numpy.random.seed(44)
def tearDown(self):
pop_mode()
def test_0(self): def test_0(self):
r = ResultBase() r = ResultBase()
...@@ -213,10 +236,13 @@ class Numpy2(ResultBase): ...@@ -213,10 +236,13 @@ class Numpy2(ResultBase):
# ResultBase # ResultBase
# #
def data_filter(self, data): def data_filter(self, data):
#return numpy.asarray(data) #TODO: consider whether this is correct #TODO: decide which of these implementations is better
if 0:
if isinstance(data, numpy.ndarray): if isinstance(data, numpy.ndarray):
return data return data
raise TypeError('failed to filter data to ndarray', data) raise TypeError('failed to filter data to ndarray', data)
else:
return numpy.asarray(data)
################################ ################################
...@@ -256,7 +282,53 @@ class Numpy2(ResultBase): ...@@ -256,7 +282,53 @@ class Numpy2(ResultBase):
else: else:
raise StateError('cannot set shape after data has been set') raise StateError('cannot set shape after data has been set')
shape = property(__shape_get, __shape_set) shape = property(__shape_get, __shape_set)
def __add__(self, y): return add(self, y)
def __radd__(self, x): return add(x, self)
def __iadd__(self, y): return add_inplace(self, y)
def __sub__(self, y): return sub(self, y)
def __rsub__(self, x): return sub(x, self)
def __isub__(self, y): return sub_inplace(self, y)
def __mul__(self, y): return mul(self, y)
def __rmul__(self, x): return mul(x, self)
def __imul__(self, y): return mul_inplace(self, y)
def __div__(self, y): return div(self, y)
def __rdiv__(self, x): return div(x, self)
def __idiv__(self, y): return div_inplace(self, y)
def __pow__(self, y): return pow(self, y)
def __rpow__(self, x): return pow(x, self)
def __ipow__(self, y): return pow_inplace(self, y)
def __neg__(self): return neg(self)
T = property(lambda self: transpose(self))
Tc = property(lambda self: transpose_copy(self))
def __copy__(self): return array_copy(self)
def __getitem__(self, item): return get_slice(self, item)
def __getslice__(self, *args): return get_slice(self, slice(*args))
#################
# NumpyR Compatibility
#
spec = property(lambda self: (numpy.ndarray, self.dtype, self.shape))
def set_value_inplace(self, value):
if 0 == len(self.shape):
self.data.itemset(value) # for scalars
else:
self.data[:] = value # for matrices
class _test_Numpy2(unittest.TestCase): class _test_Numpy2(unittest.TestCase):
def setUp(self):
build_eval_mode()
numpy.random.seed(44)
def tearDown(self):
pop_mode()
def test_0(self): def test_0(self):
r = Numpy2() r = Numpy2()
def test_1(self): def test_1(self):
...@@ -265,6 +337,7 @@ class _test_Numpy2(unittest.TestCase): ...@@ -265,6 +337,7 @@ class _test_Numpy2(unittest.TestCase):
self.failUnless(r.data is o) self.failUnless(r.data is o)
self.failUnless(r.shape == (3,3)) self.failUnless(r.shape == (3,3))
self.failUnless(str(r.dtype) == 'float64') self.failUnless(str(r.dtype) == 'float64')
def test_2(self): def test_2(self):
r = Numpy2(data=[(3,3),'int32']) r = Numpy2(data=[(3,3),'int32'])
self.failUnless(r.data is None) self.failUnless(r.data is None)
...@@ -275,27 +348,45 @@ class _test_Numpy2(unittest.TestCase): ...@@ -275,27 +348,45 @@ class _test_Numpy2(unittest.TestCase):
self.failUnless(r.shape == (3,3)) self.failUnless(r.shape == (3,3))
self.failUnless(str(r.dtype) == 'int32') self.failUnless(str(r.dtype) == 'int32')
def test_3(self):
a = Numpy2(data=numpy.ones((2,2)))
b = Numpy2(data=numpy.ones((2,2)))
c = add(a,b)
self.failUnless(_approx_eq(c, numpy.ones((2,2))*2))
def test_4(self):
ones = numpy.ones((2,2))
a = Numpy2(data=ones)
o = numpy.asarray(a)
self.failUnless((ones == o).all())
def test_5(self):
ones = numpy.ones((2,2))
self.failUnless(_approx_eq(Numpy2(data=ones), Numpy2(data=ones)))
def input(x): def input(x):
#static member initialization #static member initialization
if not hasattr(input, 'float_dtype'): if not hasattr(input, 'float_dtype'):
input.float_dtype = 'float64' input.float_dtype = 'float64'
input.int_dtype = 'int64' input.int_dtype = 'int64'
input.NN = NumpyR input.NN = Numpy2
if isinstance(x, numpy.ndarray): if isinstance(x, numpy.ndarray):
return input.NN(x) #return NumpyR(x)
return input.NN(data=x)
elif isinstance(x, int): elif isinstance(x, int):
z = numpy.zeros((), dtype = input.int_dtype) z = numpy.zeros((), dtype = input.int_dtype)
z += x z += x
return input.NN(z) return input.NN(data=z)
elif isinstance(x, float): elif isinstance(x, float):
z = numpy.zeros((), dtype = input.float_dtype) z = numpy.zeros((), dtype = input.float_dtype)
z += x z += x
return input.NN(z) return input.NN(data=z)
elif isinstance(x, gof.Result): elif is_result(x):
raise TypeError("%s is already a result." % x) raise TypeError("%s is already a result." % x)
else: else:
return ResultValue(x) return ResultBase(x)
class _testCase_input(unittest.TestCase): class _testCase_input(unittest.TestCase):
def setUp(self): def setUp(self):
literal.hdb = {} literal.hdb = {}
...@@ -312,11 +403,11 @@ class _testCase_input(unittest.TestCase): ...@@ -312,11 +403,11 @@ class _testCase_input(unittest.TestCase):
self.failUnless(w.data == 3.0) self.failUnless(w.data == 3.0)
def wrap(x): def wrap(x):
if isinstance(x, NumpyR): if isinstance(x, Numpy2):
return x
elif isinstance(x, Numpy2):
return x return x
elif isinstance(x, ResultValue): #elif isinstance(x, NumpyR):
#return x
elif is_result(x):
return x return x
elif isinstance(x, omega_op): elif isinstance(x, omega_op):
return x.out return x.out
...@@ -482,7 +573,7 @@ class omega_op(gof.PythonOp): ...@@ -482,7 +573,7 @@ class omega_op(gof.PythonOp):
return gof.PythonOp.__new__(cls, *inputs) return gof.PythonOp.__new__(cls, *inputs)
def gen_outputs(self): def gen_outputs(self):
return [NumpyR() for i in xrange(self.nout)] return [Numpy2() for i in xrange(self.nout)]
def update_gradient(self, grad_d): def update_gradient(self, grad_d):
"""Call self.grad() and add the result to grad_d """Call self.grad() and add the result to grad_d
...@@ -504,7 +595,7 @@ class omega_op(gof.PythonOp): ...@@ -504,7 +595,7 @@ class omega_op(gof.PythonOp):
""" """
inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs])) inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
if len(self.inputs) == 1 and isinstance(inputgs, gof.ResultValue): if len(self.inputs) == 1 and gof.result.is_result(inputgs):
inputgs = [inputgs] inputgs = [inputgs]
else: else:
assert len(inputgs) == len(self.inputs) assert len(inputgs) == len(self.inputs)
...@@ -535,6 +626,11 @@ class omega_op(gof.PythonOp): ...@@ -535,6 +626,11 @@ class omega_op(gof.PythonOp):
def c_impl(inputs, outputs): def c_impl(inputs, outputs):
raise NotImplementedError() raise NotImplementedError()
def c_compile_args(self):
# I always used these, but they don't make much improvement
#'-ffast-math', '-falign-loops=8'
return ['-O2']
def c_thunk_factory(self): def c_thunk_factory(self):
self.refresh() self.refresh()
d, names, code, struct, converters = self.c_code() d, names, code, struct, converters = self.c_code()
...@@ -549,10 +645,8 @@ class omega_op(gof.PythonOp): ...@@ -549,10 +645,8 @@ class omega_op(gof.PythonOp):
global_dict = {}, global_dict = {},
type_converters = converters) type_converters = converters)
instantiate.customize.add_support_code(self.c_support_code() + struct) instantiate.customize.add_support_code(self.c_support_code() + struct)
instantiate.customize.add_extra_compile_arg("-O3") for arg in self.c_compile_args():
instantiate.customize.add_extra_compile_arg("-ffast-math") #TODO: make this optional, say by passing args to c_thunk_factory? instantiate.customize.add_extra_compile_arg(arg)
instantiate.customize.add_extra_compile_arg("-falign-loops=4")
# instantiate.customize.add_extra_compile_arg("-mfpmath=sse")
for header in self.c_headers(): for header in self.c_headers():
instantiate.customize.add_header(header) instantiate.customize.add_header(header)
for lib in self.c_libs(): for lib in self.c_libs():
...@@ -725,7 +819,7 @@ class elemwise(omega_op): ...@@ -725,7 +819,7 @@ class elemwise(omega_op):
try: try:
dtype = upcast(*[input.spec[1] dtype = upcast(*[input.spec[1]
for iname, input in zip(inames, self.inputs) for iname, input in zip(inames, self.inputs)
if isinstance(input, NumpyR)]) if input.spec[0] is numpy.ndarray])
except IndexError: except IndexError:
raise Exception("not all numpy inputs are specified") raise Exception("not all numpy inputs are specified")
...@@ -895,7 +989,8 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None): ...@@ -895,7 +989,8 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
return normal_f(x, y) return normal_f(x, y)
return f return f
class NumpyR(gof.ResultValue): if 0:
class NumpyR(gof.ResultValue):
"""The class for storing ndarray return values from omega ops. """The class for storing ndarray return values from omega ops.
The class provides additional functionality compared to the normal The class provides additional functionality compared to the normal
ResultValue: ResultValue:
...@@ -1215,7 +1310,10 @@ class dot(omega_op): ...@@ -1215,7 +1310,10 @@ class dot(omega_op):
def c_impl((_x, _y), (_z, )): def c_impl((_x, _y), (_z, )):
return blas.gemm_code('', '1.0', '0.0') return blas.gemm_code('', '1.0', '0.0')
class _testCase_dot(unittest.TestCase): if 0:
print 'SKIPPING DOT TESTS'
else:
class _testCase_dot(unittest.TestCase):
def setUp(self): def setUp(self):
build_eval_mode() build_eval_mode()
numpy.random.seed(44) numpy.random.seed(44)
...@@ -1640,14 +1738,22 @@ class _testCase_power(unittest.TestCase): ...@@ -1640,14 +1738,22 @@ class _testCase_power(unittest.TestCase):
numpy.random.seed(44) numpy.random.seed(44)
def tearDown(self): def tearDown(self):
pop_mode() pop_mode()
def test1(self):
r = numpy.random.rand(50)
exp_r = exp(r)
self.failUnless(exp_r.__array__().__class__ is numpy.ndarray)
def test_0(self): def test_0(self):
r = numpy.random.rand(50) r = numpy.random.rand(50)
exp_r = exp(r) exp_r = exp(r)
self.failUnless( _approx_eq(exp_r, numpy.exp(r))) n_exp_r = numpy.exp(r)
self.failUnless( _approx_eq(exp_r, n_exp_r),
(exp_r, exp_r.data, n_exp_r,
numpy.max(numpy.abs(n_exp_r.__sub__(exp_r.__array__())))))
log_exp_r = log(exp_r) log_exp_r = log(exp_r)
self.failUnless( _approx_eq(log_exp_r, r)) self.failUnless( _approx_eq(log_exp_r, r), log_exp_r)
def test_1(self): def test_1(self):
r = numpy.random.rand(50) r = numpy.random.rand(50)
......
...@@ -5,7 +5,7 @@ import graph ...@@ -5,7 +5,7 @@ import graph
from utils import ClsInit from utils import ClsInit
from err import GofError, GofTypeError, PropagationError from err import GofError, GofTypeError, PropagationError
from op import Op from op import Op
from result import Result from result import is_result
from features import Listener, Orderings, Constraint, Tool from features import Listener, Orderings, Constraint, Tool
import utils import utils
...@@ -14,10 +14,10 @@ __all__ = ['InconsistencyError', ...@@ -14,10 +14,10 @@ __all__ = ['InconsistencyError',
# class AliasDict(dict): # class AliasDict(dict):
# "Utility class to keep track of what Result has been replaced with what Result." # "Utility class to keep track of what result has been replaced with what result."
# def group(self, main, *keys): # def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main." # "Marks all the keys as having been replaced by the result main."
# keys = [key for key in keys if key is not main] # keys = [key for key in keys if key is not main]
# if self.has_key(main): # if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.") # raise Exception("Only group results that have not been grouped before.")
...@@ -284,9 +284,11 @@ class Env(graph.Graph): ...@@ -284,9 +284,11 @@ class Env(graph.Graph):
be raised if there are type mismatches. be raised if there are type mismatches.
""" """
# Assert that they are Result instances. # Assert that they are result instances.
assert isinstance(r, Result) if not is_result(r):
assert isinstance(new_r, Result) raise TypeError(r)
if not is_result(new_r):
raise TypeError(new_r)
# Save where we are so we can backtrack # Save where we are so we can backtrack
if consistency_check: if consistency_check:
......
from op import Op from op import Op
from result import Result #, HolderResult from result import Result, is_result
from utils import ClsInit, Keyword, AbstractFunctionError from utils import ClsInit, Keyword, AbstractFunctionError
import opt import opt
import env import env
...@@ -8,15 +8,14 @@ import features ...@@ -8,15 +8,14 @@ import features
import ext import ext
__all__ = ['UNCOMPUTED', __all__ = [ 'UNDEFINED',
'UNDEFINED',
'current_mode', 'current_mode',
'set_mode', 'set_mode',
'build_mode', 'build_mode',
'eval_mode', 'eval_mode',
'build_eval_mode', 'build_eval_mode',
'pop_mode', 'pop_mode',
'ResultValue', #'ResultValue',
'DummyOp', 'DummyOp',
'DummyRemover', 'DummyRemover',
'PythonOp', 'PythonOp',
...@@ -24,7 +23,6 @@ __all__ = ['UNCOMPUTED', ...@@ -24,7 +23,6 @@ __all__ = ['UNCOMPUTED',
'make_static'] 'make_static']
UNCOMPUTED = Keyword("UNCOMPUTED", False)
UNDEFINED = Keyword("UNDEFINED", False) UNDEFINED = Keyword("UNDEFINED", False)
def make_static(cls, fname): def make_static(cls, fname):
...@@ -58,11 +56,6 @@ def compute(*nodes): ...@@ -58,11 +56,6 @@ def compute(*nodes):
"""Recursively evaluate each node (in a quick & dirty way).""" """Recursively evaluate each node (in a quick & dirty way)."""
compute_from(nodes, set()) compute_from(nodes, set())
def is_result(obj):
"""Return True iff obj provides the interface of a Result"""
attr_list = 'data', 'owner'
return all([hasattr(obj, attr) for attr in attr_list])
class ForbidConstantOverwrite(features.Listener, features.Constraint): class ForbidConstantOverwrite(features.Listener, features.Constraint):
def __init__(self, env): def __init__(self, env):
...@@ -105,7 +98,8 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint): ...@@ -105,7 +98,8 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
class ResultValue(Result): if 0:
class ResultValue(Result):
"""Augment Result to wrap a computed value. """Augment Result to wrap a computed value.
Attributes: Attributes:
...@@ -128,7 +122,7 @@ class ResultValue(Result): ...@@ -128,7 +122,7 @@ class ResultValue(Result):
__slots__ = ['data', 'spec', 'constant', 'up_to_date'] __slots__ = ['data', 'spec', 'constant', 'up_to_date']
def __init__(self, x = UNCOMPUTED, constant = False): def __init__(self, x = None, constant = False):
self.constant = False self.constant = False
self.set_value(x) # allow set_value before constant = True self.set_value(x) # allow set_value before constant = True
self.constant = constant self.constant = constant
...@@ -152,7 +146,7 @@ class ResultValue(Result): ...@@ -152,7 +146,7 @@ class ResultValue(Result):
raise Exception("This Result is a constant. Its value cannot be changed.") raise Exception("This Result is a constant. Its value cannot be changed.")
if value is None or value is UNCOMPUTED: if value is None or value is UNCOMPUTED:
self.data = UNCOMPUTED self.data = UNCOMPUTED
elif isinstance(value, ResultValue): elif is_result(value):
self.set_value(value.data) self.set_value(value.data)
else: else:
try: try:
...@@ -467,7 +461,8 @@ class PythonOp(Op): ...@@ -467,7 +461,8 @@ class PythonOp(Op):
return all([ is_result(i) for i in self.inputs]) return all([ is_result(i) for i in self.inputs])
def gen_outputs(self): def gen_outputs(self):
return [ResultValue() for i in xrange(self.nout)] raise NotImplementedError()
#return [ResultValue() for i in xrange(self.nout)]
def view_map(self): return {} def view_map(self): return {}
...@@ -500,7 +495,7 @@ class PythonOp(Op): ...@@ -500,7 +495,7 @@ class PythonOp(Op):
return answer return answer
def check_input(self, input): def check_input(self, input):
if input.data is UNCOMPUTED: if input.data is None:
raise ValueError("Uncomputed input: %s in %s" % (input, self)) raise ValueError("Uncomputed input: %s in %s" % (input, self))
if not self.input_is_up_to_date(input): if not self.input_is_up_to_date(input):
raise ValueError("Input is out of date: %s in %s" % (input, self)) raise ValueError("Input is out of date: %s in %s" % (input, self))
......
...@@ -10,7 +10,7 @@ from err import GofError ...@@ -10,7 +10,7 @@ from err import GofError
from utils import AbstractFunctionError from utils import AbstractFunctionError
__all__ = ['Result', 'BrokenLink', 'BrokenLinkError'] __all__ = ['is_result', 'Result', 'BrokenLink', 'BrokenLinkError']
class BrokenLink: class BrokenLink:
...@@ -40,6 +40,11 @@ class BrokenLinkError(GofError): ...@@ -40,6 +40,11 @@ class BrokenLinkError(GofError):
# Result # Result
############################ ############################
def is_result(obj):
"""Return True iff obj provides the interface of a Result"""
attr_list = 'owner',
return all([hasattr(obj, attr) for attr in attr_list])
class Result(object): class Result(object):
"""Storage node for data in a graph of Op instances. """Storage node for data in a graph of Op instances.
......
...@@ -59,23 +59,21 @@ class Grad(object): ...@@ -59,23 +59,21 @@ class Grad(object):
""" """
if dr is core.UNDEFINED: if dr is core.UNDEFINED:
# nothing to do # nothing to do
pass return
else:
if r.data is core.UNCOMPUTED or dr.data is core.UNCOMPUTED: if r.data is not None and dr.data is not None:
pass # no sanity checking if not hasattr(r, 'shape'):
else: # some sanity checking to catch obvious mistakes
if not hasattr(r.data, 'shape'):
raise ValueError(('Grad::add r lacks shape: type=', raise ValueError(('Grad::add r lacks shape: type=',
type(r.data))) type(r)))
if not hasattr(dr.data, 'shape'): if not hasattr(dr, 'shape'):
raise ValueError(('Grad::add dr lacks shape: type=', raise ValueError(('Grad::add dr lacks shape: type=',
type(dr.data))) type(dr)))
if r.data.shape != dr.data.shape: if r.shape != dr.shape:
raise ValueError(('Grad::add r, dr shape mismatch', raise ValueError(('Grad::add r, dr shape mismatch',
r.data.shape, dr.data.shape)) r.shape, dr.shape))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode # prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if r.data is not core.UNCOMPUTED: if r.computed:
self._compute_history.add(r) self._compute_history.add(r)
# add dr to self[r] # add dr to self[r]
...@@ -213,14 +211,14 @@ class _testCase (unittest.TestCase): ...@@ -213,14 +211,14 @@ class _testCase (unittest.TestCase):
diff = wwi - ident diff = wwi - ident
ssdiff = core.sum((diff**2)) ssdiff = core.sum((diff**2))
if i == 0: if i == 0:
str0 = str_ssdiff = str(ssdiff) str0 = str_ssdiff = str(ssdiff.data)
#print ssdiff #print ssdiff
g = grad(ssdiff) g = grad(ssdiff)
gw = g(w) gw = g(w)
w.data += -0.4 * gw.data w.data[:] += -0.4 * gw.data
return str0, str(ssdiff) return str0, str(ssdiff.data)
def matinv_compiled(self, dim): def matinv_compiled(self, dim):
w = core.wrap(numpy.random.rand(dim,dim)) w = core.wrap(numpy.random.rand(dim,dim))
...@@ -230,7 +228,7 @@ class _testCase (unittest.TestCase): ...@@ -230,7 +228,7 @@ class _testCase (unittest.TestCase):
wwi = core.dot(w, wi) wwi = core.dot(w, wi)
diff = wwi - ident diff = wwi - ident
ssdiff = core.sum((diff**2)) ssdiff = core.sum((diff**2))
str0 = str_ssdiff = str(ssdiff) str0 = str_ssdiff = str(ssdiff.data)
#print ssdiff #print ssdiff
g = grad(ssdiff) g = grad(ssdiff)
...@@ -240,9 +238,9 @@ class _testCase (unittest.TestCase): ...@@ -240,9 +238,9 @@ class _testCase (unittest.TestCase):
for i in xrange(300): for i in xrange(300):
prog() prog()
w.data += -0.4 * gw.data w.data[:] += -0.4 * gw.data
return str0, str(ssdiff) return str0, str(ssdiff.data)
def test0(self): def test0(self):
"""Matrix inversion by gradient descent (eval mode)""" """Matrix inversion by gradient descent (eval mode)"""
......
...@@ -9,7 +9,7 @@ import grad ...@@ -9,7 +9,7 @@ import grad
# Wrapper type # Wrapper type
class SparseR(gof.ResultValue): class SparseR(core.ResultBase):
""" """
Attribute: Attribute:
format - a subclass of sparse.spmatrix indicating self.data.__class__ format - a subclass of sparse.spmatrix indicating self.data.__class__
...@@ -22,12 +22,17 @@ class SparseR(gof.ResultValue): ...@@ -22,12 +22,17 @@ class SparseR(gof.ResultValue):
Notes: Notes:
""" """
def __init__(self, x = core.UNCOMPUTED, constant = False, def __init__(self, data=None, role=None, constant = False,
format = sparse.csr_matrix): format = sparse.csr_matrix):
gof.ResultValue.__init__(self, x, constant) core.ResultBase.__init__(self, role, data, constant)
self.format = isinstance(x, sparse.spmatrix) and x.__class__ or format if isinstance(data, sparse.spmatrix):
self.format = data.__class__
else:
self.format = format
self._dtype = None
self._shape = None
def set_value_filter(self, value): def data_filter(self, value):
if isinstance(value, sparse.spmatrix): return value if isinstance(value, sparse.spmatrix): return value
return sparse.csr_matrix(value) return sparse.csr_matrix(value)
...@@ -36,6 +41,32 @@ class SparseR(gof.ResultValue): ...@@ -36,6 +41,32 @@ class SparseR(gof.ResultValue):
T = property(lambda self: transpose(self), doc = "Return aliased transpose") T = property(lambda self: transpose(self), doc = "Return aliased transpose")
# self._dtype is used when self._data hasn't been set yet
def __dtype_get(self):
if self._data is None:
return self._dtype
else:
return self._data.dtype
def __dtype_set(self, dtype):
if self._data is None:
self._dtype = dtype
else:
raise StateError('cannot set dtype after data has been set')
dtype = property(__dtype_get, __dtype_set)
# self._shape is used when self._data hasn't been set yet
def __shape_get(self):
if self._data is None:
return self._shape
else:
return self._data.shape
def __shape_set(self, shape):
if self._data is None:
self._shape = shape
else:
raise StateError('cannot set shape after data has been set')
shape = property(__shape_get, __shape_set)
# convenience base class # convenience base class
class op(gof.PythonOp, grad.update_gradient_via_grad): class op(gof.PythonOp, grad.update_gradient_via_grad):
pass pass
...@@ -46,7 +77,7 @@ class op(gof.PythonOp, grad.update_gradient_via_grad): ...@@ -46,7 +77,7 @@ class op(gof.PythonOp, grad.update_gradient_via_grad):
# convert a sparse matrix to an ndarray # convert a sparse matrix to an ndarray
class sparse2dense(op): class sparse2dense(op):
def gen_outputs(self): return [core.NumpyR()] def gen_outputs(self): return [core.Numpy2()]
def impl(x): return numpy.asarray(x.todense()) def impl(x): return numpy.asarray(x.todense())
def grad(self, x, gz): def grad(self, x, gz):
if x.format is sparse.coo_matrix: return dense2coo(gz) if x.format is sparse.coo_matrix: return dense2coo(gz)
...@@ -135,7 +166,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -135,7 +166,7 @@ class _testCase_dot(unittest.TestCase):
def test_basic0(self): def test_basic0(self):
for mtype in [sparse.csc_matrix, sparse.csr_matrix]: for mtype in [sparse.csc_matrix, sparse.csr_matrix]:
x = SparseR(mtype(sparse.speye(5,3))) x = SparseR(mtype(sparse.speye(5,3)))
y = core.NumpyR(numpy.random.rand(3, 2)) y = core.wrap(numpy.random.rand(3, 2))
z = dot(x,y) z = dot(x,y)
self.failUnless(z.data.shape == (5,2)) self.failUnless(z.data.shape == (5,2))
...@@ -171,7 +202,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -171,7 +202,7 @@ class _testCase_dot(unittest.TestCase):
self.failUnless(z.data.shape == ab.shape) self.failUnless(z.data.shape == ab.shape)
self.failUnless(type(z.data) == type(ab)) self.failUnless(type(z.data) == type(ab))
def test_graph_bprop0(self): def test_graph_bprop0(self):
x = core.NumpyR(numpy.random.rand(10,2)) x = core.wrap(numpy.random.rand(10,2))
w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0, w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0,
0]],dtype='float64'))) 0]],dtype='float64')))
...@@ -187,10 +218,10 @@ class _testCase_dot(unittest.TestCase): ...@@ -187,10 +218,10 @@ class _testCase_dot(unittest.TestCase):
g(w).data[1,4] = 0 g(w).data[1,4] = 0
w.data = -lr * g(w).data + w.data w.data = -lr * g(w).data + w.data
self.failUnless('3.08560636025' == str(loss)) self.failUnless('3.08560636025' == str(loss.data))
def test_graph_bprop1(self): def test_graph_bprop1(self):
x = core.NumpyR(numpy.random.rand(10,2)) x = core.wrap(numpy.random.rand(10,2))
w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0, w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0,
0]],dtype='float64'))) 0]],dtype='float64')))
...@@ -205,7 +236,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -205,7 +236,7 @@ class _testCase_dot(unittest.TestCase):
g(w).data[1,4] = 0 g(w).data[1,4] = 0
w.data = -lr * g(w).data + w.data w.data = -lr * g(w).data + w.data
self.failUnless('3.08560636025' == str(loss)) self.failUnless('3.08560636025' == str(loss.data))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论