broken

上级 460f6b78
import time import time, unittest
import numpy
import gof import gof
import gof.lib import gof.lib
...@@ -170,13 +171,13 @@ class prog(gof.Prog): ...@@ -170,13 +171,13 @@ class prog(gof.Prog):
""" """
if check_uncomputed: if check_uncomputed:
for input in self.env.inputs: for input in self.env.inputs:
if input.data is core.UNCOMPUTED: if input.data is None:
raise Exception("You must provide a value for input %s!" % input) raise Exception("You must provide a value for input %s!" % input)
return gof.Prog.__call__(self) return gof.Prog.__call__(self)
def compute_orphans(self): def compute_orphans(self):
for orphan in self.env.orphans(): for orphan in self.env.orphans():
if orphan.data is core.UNCOMPUTED: if orphan.data is None:
if orphan.owner: if orphan.owner:
gof.lib.compute(orphan.owner) gof.lib.compute(orphan.owner)
else: else:
...@@ -200,3 +201,22 @@ def to_func(inputs, outputs): ...@@ -200,3 +201,22 @@ def to_func(inputs, outputs):
def single(*outputs, **kwargs): def single(*outputs, **kwargs):
return prog(gof.graph.inputs(outputs), outputs, **kwargs) return prog(gof.graph.inputs(outputs), outputs, **kwargs)
class _test_single(unittest.TestCase):
def setUp(self):
core.build_eval_mode()
numpy.random.seed(44)
def tearDown(self):
core.pop_mode()
def test_3(self):
a = core.Numpy2(data=numpy.ones((2,2)))
b = core.Numpy2(data=numpy.ones((2,2)))
c = core.add(a,b)
p = single(c)
p()
self.failUnless(core._approx_eq(c, numpy.ones((2,2))*2))
if __name__ == '__main__':
unittest.main()
...@@ -12,7 +12,7 @@ from scipy import weave ...@@ -12,7 +12,7 @@ from scipy import weave
import gof import gof
from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode
from gof import pop_mode, UNCOMPUTED, UNDEFINED, ResultValue from gof import pop_mode, UNDEFINED, is_result
import type_spec import type_spec
import cutils import cutils
...@@ -45,7 +45,7 @@ def _approx_eq(a,b,eps=1.0e-9): ...@@ -45,7 +45,7 @@ def _approx_eq(a,b,eps=1.0e-9):
b = numpy.asarray(b) b = numpy.asarray(b)
if a.shape != b.shape: if a.shape != b.shape:
return False return False
return numpy.max( numpy.abs(a-b)) < eps return numpy.max(numpy.abs(a-b)) < eps
# This function is only executed the first time it is called, subsequent calls # This function is only executed the first time it is called, subsequent calls
...@@ -84,7 +84,7 @@ def _compile_dir(): ...@@ -84,7 +84,7 @@ def _compile_dir():
sys.path.append(cachedir) sys.path.append(cachedir)
return cachedir return cachedir
class ResultBase: class ResultBase(object):
"""Base class for storing Op inputs and outputs """Base class for storing Op inputs and outputs
Attributes: Attributes:
...@@ -98,6 +98,7 @@ class ResultBase: ...@@ -98,6 +98,7 @@ class ResultBase:
index - (ro) index - (ro)
data - (rw) data - (rw)
replaced - (rw) : True iff _role is BrokenLink replaced - (rw) : True iff _role is BrokenLink
computed - (ro) : True iff contents of data are fresh
Abstract Methods: Abstract Methods:
data_filter data_filter
...@@ -170,11 +171,12 @@ class ResultBase: ...@@ -170,11 +171,12 @@ class ResultBase:
if self.replaced: raise ResultBase.BrokenLinkError() if self.replaced: raise ResultBase.BrokenLinkError()
if self.constant: raise Exception('cannot set constant ResultBase') if self.constant: raise Exception('cannot set constant ResultBase')
try: try:
self._data = self.data_filter(self, data) self._data = self.data_filter(data)
except ResultBase.AbstractFunction: except ResultBase.AbstractFunction: #use default behaviour
self._data = data self._data = data
data = property(__get_data, __set_data, data = property(__get_data, __set_data,
doc = "The storage associated with this result") doc = "The storage associated with this result")
def data_filter(self, data): def data_filter(self, data):
"""(abstract) Return an appropriate _data based on data.""" """(abstract) Return an appropriate _data based on data."""
raise ResultBase.AbstractFunction() raise ResultBase.AbstractFunction()
...@@ -189,7 +191,28 @@ class ResultBase: ...@@ -189,7 +191,28 @@ class ResultBase:
else: else:
self._role = self._role.old_role self._role = self._role.old_role
replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?") replaced = property(__get_replaced, __set_replaced, doc = "has this Result been replaced?")
# computed
#TODO: think about how to handle this more correctly
computed = property(lambda self: self._data is not None)
#################
# NumpyR Compatibility
#
up_to_date = property(lambda self: True)
def refresh(self): pass
def set_owner(self, owner, idx):
self.role = (owner, idx)
def set_value(self, value):
self.data = value #may raise exception
class _test_ResultBase(unittest.TestCase): class _test_ResultBase(unittest.TestCase):
def setUp(self):
build_eval_mode()
numpy.random.seed(44)
def tearDown(self):
pop_mode()
def test_0(self): def test_0(self):
r = ResultBase() r = ResultBase()
...@@ -213,10 +236,13 @@ class Numpy2(ResultBase): ...@@ -213,10 +236,13 @@ class Numpy2(ResultBase):
# ResultBase # ResultBase
# #
def data_filter(self, data): def data_filter(self, data):
#return numpy.asarray(data) #TODO: consider whether this is correct #TODO: decide which of these implementations is better
if isinstance(data, numpy.ndarray): if 0:
return data if isinstance(data, numpy.ndarray):
raise TypeError('failed to filter data to ndarray', data) return data
raise TypeError('failed to filter data to ndarray', data)
else:
return numpy.asarray(data)
################################ ################################
...@@ -256,7 +282,53 @@ class Numpy2(ResultBase): ...@@ -256,7 +282,53 @@ class Numpy2(ResultBase):
else: else:
raise StateError('cannot set shape after data has been set') raise StateError('cannot set shape after data has been set')
shape = property(__shape_get, __shape_set) shape = property(__shape_get, __shape_set)
def __add__(self, y): return add(self, y)
def __radd__(self, x): return add(x, self)
def __iadd__(self, y): return add_inplace(self, y)
def __sub__(self, y): return sub(self, y)
def __rsub__(self, x): return sub(x, self)
def __isub__(self, y): return sub_inplace(self, y)
def __mul__(self, y): return mul(self, y)
def __rmul__(self, x): return mul(x, self)
def __imul__(self, y): return mul_inplace(self, y)
def __div__(self, y): return div(self, y)
def __rdiv__(self, x): return div(x, self)
def __idiv__(self, y): return div_inplace(self, y)
def __pow__(self, y): return pow(self, y)
def __rpow__(self, x): return pow(x, self)
def __ipow__(self, y): return pow_inplace(self, y)
def __neg__(self): return neg(self)
T = property(lambda self: transpose(self))
Tc = property(lambda self: transpose_copy(self))
def __copy__(self): return array_copy(self)
def __getitem__(self, item): return get_slice(self, item)
def __getslice__(self, *args): return get_slice(self, slice(*args))
#################
# NumpyR Compatibility
#
spec = property(lambda self: (numpy.ndarray, self.dtype, self.shape))
def set_value_inplace(self, value):
if 0 == len(self.shape):
self.data.itemset(value) # for scalars
else:
self.data[:] = value # for matrices
class _test_Numpy2(unittest.TestCase): class _test_Numpy2(unittest.TestCase):
def setUp(self):
build_eval_mode()
numpy.random.seed(44)
def tearDown(self):
pop_mode()
def test_0(self): def test_0(self):
r = Numpy2() r = Numpy2()
def test_1(self): def test_1(self):
...@@ -265,6 +337,7 @@ class _test_Numpy2(unittest.TestCase): ...@@ -265,6 +337,7 @@ class _test_Numpy2(unittest.TestCase):
self.failUnless(r.data is o) self.failUnless(r.data is o)
self.failUnless(r.shape == (3,3)) self.failUnless(r.shape == (3,3))
self.failUnless(str(r.dtype) == 'float64') self.failUnless(str(r.dtype) == 'float64')
def test_2(self): def test_2(self):
r = Numpy2(data=[(3,3),'int32']) r = Numpy2(data=[(3,3),'int32'])
self.failUnless(r.data is None) self.failUnless(r.data is None)
...@@ -275,27 +348,45 @@ class _test_Numpy2(unittest.TestCase): ...@@ -275,27 +348,45 @@ class _test_Numpy2(unittest.TestCase):
self.failUnless(r.shape == (3,3)) self.failUnless(r.shape == (3,3))
self.failUnless(str(r.dtype) == 'int32') self.failUnless(str(r.dtype) == 'int32')
def test_3(self):
a = Numpy2(data=numpy.ones((2,2)))
b = Numpy2(data=numpy.ones((2,2)))
c = add(a,b)
self.failUnless(_approx_eq(c, numpy.ones((2,2))*2))
def test_4(self):
ones = numpy.ones((2,2))
a = Numpy2(data=ones)
o = numpy.asarray(a)
self.failUnless((ones == o).all())
def test_5(self):
ones = numpy.ones((2,2))
self.failUnless(_approx_eq(Numpy2(data=ones), Numpy2(data=ones)))
def input(x): def input(x):
#static member initialization #static member initialization
if not hasattr(input, 'float_dtype'): if not hasattr(input, 'float_dtype'):
input.float_dtype = 'float64' input.float_dtype = 'float64'
input.int_dtype = 'int64' input.int_dtype = 'int64'
input.NN = NumpyR input.NN = Numpy2
if isinstance(x, numpy.ndarray): if isinstance(x, numpy.ndarray):
return input.NN(x) #return NumpyR(x)
return input.NN(data=x)
elif isinstance(x, int): elif isinstance(x, int):
z = numpy.zeros((), dtype = input.int_dtype) z = numpy.zeros((), dtype = input.int_dtype)
z += x z += x
return input.NN(z) return input.NN(data=z)
elif isinstance(x, float): elif isinstance(x, float):
z = numpy.zeros((), dtype = input.float_dtype) z = numpy.zeros((), dtype = input.float_dtype)
z += x z += x
return input.NN(z) return input.NN(data=z)
elif isinstance(x, gof.Result): elif is_result(x):
raise TypeError("%s is already a result." % x) raise TypeError("%s is already a result." % x)
else: else:
return ResultValue(x) return ResultBase(x)
class _testCase_input(unittest.TestCase): class _testCase_input(unittest.TestCase):
def setUp(self): def setUp(self):
literal.hdb = {} literal.hdb = {}
...@@ -312,11 +403,11 @@ class _testCase_input(unittest.TestCase): ...@@ -312,11 +403,11 @@ class _testCase_input(unittest.TestCase):
self.failUnless(w.data == 3.0) self.failUnless(w.data == 3.0)
def wrap(x): def wrap(x):
if isinstance(x, NumpyR): if isinstance(x, Numpy2):
return x return x
elif isinstance(x, Numpy2): #elif isinstance(x, NumpyR):
return x #return x
elif isinstance(x, ResultValue): elif is_result(x):
return x return x
elif isinstance(x, omega_op): elif isinstance(x, omega_op):
return x.out return x.out
...@@ -482,7 +573,7 @@ class omega_op(gof.PythonOp): ...@@ -482,7 +573,7 @@ class omega_op(gof.PythonOp):
return gof.PythonOp.__new__(cls, *inputs) return gof.PythonOp.__new__(cls, *inputs)
def gen_outputs(self): def gen_outputs(self):
return [NumpyR() for i in xrange(self.nout)] return [Numpy2() for i in xrange(self.nout)]
def update_gradient(self, grad_d): def update_gradient(self, grad_d):
"""Call self.grad() and add the result to grad_d """Call self.grad() and add the result to grad_d
...@@ -504,7 +595,7 @@ class omega_op(gof.PythonOp): ...@@ -504,7 +595,7 @@ class omega_op(gof.PythonOp):
""" """
inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs])) inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
if len(self.inputs) == 1 and isinstance(inputgs, gof.ResultValue): if len(self.inputs) == 1 and gof.result.is_result(inputgs):
inputgs = [inputgs] inputgs = [inputgs]
else: else:
assert len(inputgs) == len(self.inputs) assert len(inputgs) == len(self.inputs)
...@@ -535,6 +626,11 @@ class omega_op(gof.PythonOp): ...@@ -535,6 +626,11 @@ class omega_op(gof.PythonOp):
def c_impl(inputs, outputs): def c_impl(inputs, outputs):
raise NotImplementedError() raise NotImplementedError()
def c_compile_args(self):
# I always used these, but they don't make much improvement
#'-ffast-math', '-falign-loops=8'
return ['-O2']
def c_thunk_factory(self): def c_thunk_factory(self):
self.refresh() self.refresh()
d, names, code, struct, converters = self.c_code() d, names, code, struct, converters = self.c_code()
...@@ -549,10 +645,8 @@ class omega_op(gof.PythonOp): ...@@ -549,10 +645,8 @@ class omega_op(gof.PythonOp):
global_dict = {}, global_dict = {},
type_converters = converters) type_converters = converters)
instantiate.customize.add_support_code(self.c_support_code() + struct) instantiate.customize.add_support_code(self.c_support_code() + struct)
instantiate.customize.add_extra_compile_arg("-O3") for arg in self.c_compile_args():
instantiate.customize.add_extra_compile_arg("-ffast-math") #TODO: make this optional, say by passing args to c_thunk_factory? instantiate.customize.add_extra_compile_arg(arg)
instantiate.customize.add_extra_compile_arg("-falign-loops=4")
# instantiate.customize.add_extra_compile_arg("-mfpmath=sse")
for header in self.c_headers(): for header in self.c_headers():
instantiate.customize.add_header(header) instantiate.customize.add_header(header)
for lib in self.c_libs(): for lib in self.c_libs():
...@@ -725,7 +819,7 @@ class elemwise(omega_op): ...@@ -725,7 +819,7 @@ class elemwise(omega_op):
try: try:
dtype = upcast(*[input.spec[1] dtype = upcast(*[input.spec[1]
for iname, input in zip(inames, self.inputs) for iname, input in zip(inames, self.inputs)
if isinstance(input, NumpyR)]) if input.spec[0] is numpy.ndarray])
except IndexError: except IndexError:
raise Exception("not all numpy inputs are specified") raise Exception("not all numpy inputs are specified")
...@@ -895,78 +989,79 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None): ...@@ -895,78 +989,79 @@ def scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
return normal_f(x, y) return normal_f(x, y)
return f return f
class NumpyR(gof.ResultValue): if 0:
"""The class for storing ndarray return values from omega ops. class NumpyR(gof.ResultValue):
The class provides additional functionality compared to the normal """The class for storing ndarray return values from omega ops.
ResultValue: The class provides additional functionality compared to the normal
- operator overloads that correspond to omega ops such as add() and scale() ResultValue:
- special attributes that make it behave like an ndarray when passed to - operator overloads that correspond to omega ops such as add() and scale()
numpy functions. - special attributes that make it behave like an ndarray when passed to
numpy functions.
Attributes:
__array__ - alias of self.data.__array_struct__ Attributes:
__array_struct__ - alias of self.data.__array_struct__ __array__ - alias of self.data.__array_struct__
__array_struct__ - alias of self.data.__array_struct__
Methods:
set_value() - Methods:
""" set_value() -
"""
# The following attributes make NumpyR instances look like normal ndarray # The following attributes make NumpyR instances look like normal ndarray
# instances to many numpy functions, such as argmax(), dot(), svd(), sum(), # instances to many numpy functions, such as argmax(), dot(), svd(), sum(),
# etc. These are documented in the numpy book. # etc. These are documented in the numpy book.
__array__ = property(lambda self: self.data.__array__ ) __array__ = property(lambda self: self.data.__array__ )
__array_struct__ = property(lambda self: self.data.__array_struct__ ) __array_struct__ = property(lambda self: self.data.__array_struct__ )
def set_value_filter(self, value): return numpy.asarray(value) def set_value_filter(self, value): return numpy.asarray(value)
def set_value_inplace(self, value): def set_value_inplace(self, value):
if value is UNCOMPUTED: if value is UNCOMPUTED:
raise ValueError() raise ValueError()
else:
if 0 == len(self.data.shape):
self.data.itemset(value) # for scalars
else: else:
self.data[:] = value # for matrices if 0 == len(self.data.shape):
self.refresh() self.data.itemset(value) # for scalars
self.up_to_date = True else:
self.data[:] = value # for matrices
def refresh(self): self.refresh()
if self.data is not UNCOMPUTED: self.up_to_date = True
self.spec = (numpy.ndarray, self.data.dtype, self.data.shape)
def refresh(self):
if self.data is not UNCOMPUTED:
self.spec = (numpy.ndarray, self.data.dtype, self.data.shape)
def alloc(self):
shape, dtype = self.spec[2], self.spec[1]
self.data = numpy.ndarray(shape, dtype=dtype)
def __add__(self, y): return add(self, y)
def __radd__(self, x): return add(x, self)
def __iadd__(self, y): return add_inplace(self, y)
def alloc(self): def __sub__(self, y): return sub(self, y)
shape, dtype = self.spec[2], self.spec[1] def __rsub__(self, x): return sub(x, self)
self.data = numpy.ndarray(shape, dtype=dtype) def __isub__(self, y): return sub_inplace(self, y)
def __add__(self, y): return add(self, y)
def __radd__(self, x): return add(x, self)
def __iadd__(self, y): return add_inplace(self, y)
def __sub__(self, y): return sub(self, y)
def __rsub__(self, x): return sub(x, self)
def __isub__(self, y): return sub_inplace(self, y)
def __mul__(self, y): return mul(self, y)
def __rmul__(self, x): return mul(x, self)
def __imul__(self, y): return mul_inplace(self, y)
def __div__(self, y): return div(self, y)
def __rdiv__(self, x): return div(x, self)
def __idiv__(self, y): return div_inplace(self, y)
def __pow__(self, y): return pow(self, y) def __mul__(self, y): return mul(self, y)
def __rpow__(self, x): return pow(x, self) def __rmul__(self, x): return mul(x, self)
def __ipow__(self, y): return pow_inplace(self, y) def __imul__(self, y): return mul_inplace(self, y)
def __div__(self, y): return div(self, y)
def __rdiv__(self, x): return div(x, self)
def __idiv__(self, y): return div_inplace(self, y)
def __pow__(self, y): return pow(self, y)
def __rpow__(self, x): return pow(x, self)
def __ipow__(self, y): return pow_inplace(self, y)
def __neg__(self): return neg(self) def __neg__(self): return neg(self)
T = property(lambda self: transpose(self)) T = property(lambda self: transpose(self))
Tc = property(lambda self: transpose_copy(self)) Tc = property(lambda self: transpose_copy(self))
def __copy__(self): return array_copy(self) def __copy__(self): return array_copy(self)
def __getitem__(self, item): return get_slice(self, item) def __getitem__(self, item): return get_slice(self, item)
def __getslice__(self, *args): return get_slice(self, slice(*args)) def __getslice__(self, *args): return get_slice(self, slice(*args))
...@@ -1215,157 +1310,160 @@ class dot(omega_op): ...@@ -1215,157 +1310,160 @@ class dot(omega_op):
def c_impl((_x, _y), (_z, )): def c_impl((_x, _y), (_z, )):
return blas.gemm_code('', '1.0', '0.0') return blas.gemm_code('', '1.0', '0.0')
class _testCase_dot(unittest.TestCase): if 0:
def setUp(self): print 'SKIPPING DOT TESTS'
build_eval_mode() else:
numpy.random.seed(44) class _testCase_dot(unittest.TestCase):
def tearDown(self): def setUp(self):
pop_mode() build_eval_mode()
numpy.random.seed(44)
@staticmethod def tearDown(self):
def rand(*args): pop_mode()
return numpy.random.rand(*args)
@staticmethod
def cmp_dot(self,x,y): def rand(*args):
def spec(x): return numpy.random.rand(*args)
def cmp_dot(self,x,y):
def spec(x):
x = numpy.asarray(x)
return type(x), x.dtype, x.shape
zspec = dot.specs(spec(x), spec(y))
nz = numpy.dot(x,y)
self.failUnless(zspec == spec(nz))
self.failUnless(_approx_eq(dot(x,y), numpy.dot(x,y)))
def cmp_dot_comp(self, x,y):
x = numpy.asarray(x) x = numpy.asarray(x)
return type(x), x.dtype, x.shape y = numpy.asarray(y)
zspec = dot.specs(spec(x), spec(y))
nz = numpy.dot(x,y)
self.failUnless(zspec == spec(nz))
self.failUnless(_approx_eq(dot(x,y), numpy.dot(x,y)))
def cmp_dot_comp(self, x,y):
x = numpy.asarray(x)
y = numpy.asarray(y)
z = dot(x,y)
p = compile.single(z)
if len(x.shape):
x[:] = numpy.random.rand(*x.shape)
else:
x.fill(numpy.random.rand(*x.shape))
if len(y.shape):
y[:] = numpy.random.rand(*y.shape)
else:
y.fill(numpy.random.rand(*y.shape))
p() # recalculate z
self.failUnless(_approx_eq(z, numpy.dot(x,y)))
def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2)
def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5))
def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7))
def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7))
def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 )
def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5))
def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7))
def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0)
def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6))
def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0)
def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7))
def test_dot_0d_0d_(self): self.cmp_dot_comp(1.1, 2.2)
def test_dot_0d_1d_(self): self.cmp_dot_comp(1.1, self.rand(5))
def test_dot_0d_2d_(self): self.cmp_dot_comp(3.0, self.rand(6,7))
def test_dot_0d_3d_(self): self.cmp_dot_comp(3.0, self.rand(8,6,7))
def test_dot_1d_0d_(self): self.cmp_dot_comp(self.rand(5), 1.1 )
def test_dot_1d_1d_(self): self.cmp_dot_comp(self.rand(5), self.rand(5))
def test_dot_1d_2d_(self): self.cmp_dot_comp(self.rand(6), self.rand(6,7))
def test_dot_1d_3d_(self): self.cmp_dot_comp(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d_(self): self.cmp_dot_comp(self.rand(5,6), 1.0)
def test_dot_2d_1d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(6))
def test_dot_2d_2d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d_(self): self.cmp_dot_comp(self.rand(4,5,6), 1.0)
def test_dot_3d_1d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(8,6,7))
def test_dot_fail_1_1(self):
x = numpy.random.rand(5)
y = numpy.random.rand(6)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_1_2(self):
x = numpy.random.rand(5)
y = numpy.random.rand(6,4)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_1_3(self):
x = numpy.random.rand(5)
y = numpy.random.rand(6,4,7)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_2_1(self):
x = numpy.random.rand(5,4)
y = numpy.random.rand(6)
try:
z = dot(x,y) z = dot(x,y)
except ValueError, e: p = compile.single(z)
self.failUnless(str(e) == 'objects are not aligned', e) if len(x.shape):
return x[:] = numpy.random.rand(*x.shape)
self.fail() else:
def test_dot_fail_2_2(self): x.fill(numpy.random.rand(*x.shape))
x = numpy.random.rand(5,4) if len(y.shape):
y = numpy.random.rand(6,7) y[:] = numpy.random.rand(*y.shape)
try: else:
z = dot(x,y) y.fill(numpy.random.rand(*y.shape))
except ValueError, e: p() # recalculate z
self.failUnless(str(e) == 'objects are not aligned', e) self.failUnless(_approx_eq(z, numpy.dot(x,y)))
return
self.fail() def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2)
def test_dot_fail_2_3(self): def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5))
x = numpy.random.rand(5,4) def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7))
y = numpy.random.rand(6,7,8) def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7))
try: def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 )
z = dot(x,y) def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5))
except ValueError, e: def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7))
self.failUnless(str(e) == 'objects are not aligned', e) def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7))
return def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0)
self.fail() def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6))
def test_dot_fail_3_1(self): def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7))
x = numpy.random.rand(5,4,3) def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7))
y = numpy.random.rand(6) def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0)
try: def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6))
z = dot(x,y) def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7))
except ValueError, e: def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7))
self.failUnless(str(e) == 'objects are not aligned', e) def test_dot_0d_0d_(self): self.cmp_dot_comp(1.1, 2.2)
return def test_dot_0d_1d_(self): self.cmp_dot_comp(1.1, self.rand(5))
self.fail() def test_dot_0d_2d_(self): self.cmp_dot_comp(3.0, self.rand(6,7))
def test_dot_fail_3_2(self): def test_dot_0d_3d_(self): self.cmp_dot_comp(3.0, self.rand(8,6,7))
x = numpy.random.rand(5,4,3) def test_dot_1d_0d_(self): self.cmp_dot_comp(self.rand(5), 1.1 )
y = numpy.random.rand(6,7) def test_dot_1d_1d_(self): self.cmp_dot_comp(self.rand(5), self.rand(5))
try: def test_dot_1d_2d_(self): self.cmp_dot_comp(self.rand(6), self.rand(6,7))
z = dot(x,y) def test_dot_1d_3d_(self): self.cmp_dot_comp(self.rand(6), self.rand(8,6,7))
except ValueError, e: def test_dot_2d_0d_(self): self.cmp_dot_comp(self.rand(5,6), 1.0)
self.failUnless(str(e) == 'objects are not aligned', e) def test_dot_2d_1d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(6))
return def test_dot_2d_2d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(6,7))
self.fail() def test_dot_2d_3d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(8,6,7))
def test_dot_fail_3_3(self): def test_dot_3d_0d_(self): self.cmp_dot_comp(self.rand(4,5,6), 1.0)
x = numpy.random.rand(5,4,3) def test_dot_3d_1d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6))
y = numpy.random.rand(6,7,8) def test_dot_3d_2d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6,7))
try: def test_dot_3d_3d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(8,6,7))
z = dot(x,y)
except ValueError, e: def test_dot_fail_1_1(self):
self.failUnless(str(e) == 'objects are not aligned', e) x = numpy.random.rand(5)
return y = numpy.random.rand(6)
self.fail() try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_1_2(self):
x = numpy.random.rand(5)
y = numpy.random.rand(6,4)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_1_3(self):
x = numpy.random.rand(5)
y = numpy.random.rand(6,4,7)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_2_1(self):
x = numpy.random.rand(5,4)
y = numpy.random.rand(6)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_2_2(self):
x = numpy.random.rand(5,4)
y = numpy.random.rand(6,7)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_2_3(self):
x = numpy.random.rand(5,4)
y = numpy.random.rand(6,7,8)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_3_1(self):
x = numpy.random.rand(5,4,3)
y = numpy.random.rand(6)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_3_2(self):
x = numpy.random.rand(5,4,3)
y = numpy.random.rand(6,7)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_3_3(self):
x = numpy.random.rand(5,4,3)
y = numpy.random.rand(6,7,8)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
class gemm(omega_op): class gemm(omega_op):
def destroy_map(self): return {self.out:[self.inputs[0]]} def destroy_map(self): return {self.out:[self.inputs[0]]}
...@@ -1640,14 +1738,22 @@ class _testCase_power(unittest.TestCase): ...@@ -1640,14 +1738,22 @@ class _testCase_power(unittest.TestCase):
numpy.random.seed(44) numpy.random.seed(44)
def tearDown(self): def tearDown(self):
pop_mode() pop_mode()
def test1(self):
r = numpy.random.rand(50)
exp_r = exp(r)
self.failUnless(exp_r.__array__().__class__ is numpy.ndarray)
def test_0(self): def test_0(self):
r = numpy.random.rand(50) r = numpy.random.rand(50)
exp_r = exp(r) exp_r = exp(r)
self.failUnless( _approx_eq(exp_r, numpy.exp(r))) n_exp_r = numpy.exp(r)
self.failUnless( _approx_eq(exp_r, n_exp_r),
(exp_r, exp_r.data, n_exp_r,
numpy.max(numpy.abs(n_exp_r.__sub__(exp_r.__array__())))))
log_exp_r = log(exp_r) log_exp_r = log(exp_r)
self.failUnless( _approx_eq(log_exp_r, r)) self.failUnless( _approx_eq(log_exp_r, r), log_exp_r)
def test_1(self): def test_1(self):
r = numpy.random.rand(50) r = numpy.random.rand(50)
......
...@@ -5,7 +5,7 @@ import graph ...@@ -5,7 +5,7 @@ import graph
from utils import ClsInit from utils import ClsInit
from err import GofError, GofTypeError, PropagationError from err import GofError, GofTypeError, PropagationError
from op import Op from op import Op
from result import Result from result import is_result
from features import Listener, Orderings, Constraint, Tool from features import Listener, Orderings, Constraint, Tool
import utils import utils
...@@ -14,10 +14,10 @@ __all__ = ['InconsistencyError', ...@@ -14,10 +14,10 @@ __all__ = ['InconsistencyError',
# class AliasDict(dict): # class AliasDict(dict):
# "Utility class to keep track of what Result has been replaced with what Result." # "Utility class to keep track of what result has been replaced with what result."
# def group(self, main, *keys): # def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main." # "Marks all the keys as having been replaced by the result main."
# keys = [key for key in keys if key is not main] # keys = [key for key in keys if key is not main]
# if self.has_key(main): # if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.") # raise Exception("Only group results that have not been grouped before.")
...@@ -284,9 +284,11 @@ class Env(graph.Graph): ...@@ -284,9 +284,11 @@ class Env(graph.Graph):
be raised if there are type mismatches. be raised if there are type mismatches.
""" """
# Assert that they are Result instances. # Assert that they are result instances.
assert isinstance(r, Result) if not is_result(r):
assert isinstance(new_r, Result) raise TypeError(r)
if not is_result(new_r):
raise TypeError(new_r)
# Save where we are so we can backtrack # Save where we are so we can backtrack
if consistency_check: if consistency_check:
......
from op import Op from op import Op
from result import Result #, HolderResult from result import Result, is_result
from utils import ClsInit, Keyword, AbstractFunctionError from utils import ClsInit, Keyword, AbstractFunctionError
import opt import opt
import env import env
...@@ -8,15 +8,14 @@ import features ...@@ -8,15 +8,14 @@ import features
import ext import ext
__all__ = ['UNCOMPUTED', __all__ = [ 'UNDEFINED',
'UNDEFINED',
'current_mode', 'current_mode',
'set_mode', 'set_mode',
'build_mode', 'build_mode',
'eval_mode', 'eval_mode',
'build_eval_mode', 'build_eval_mode',
'pop_mode', 'pop_mode',
'ResultValue', #'ResultValue',
'DummyOp', 'DummyOp',
'DummyRemover', 'DummyRemover',
'PythonOp', 'PythonOp',
...@@ -24,7 +23,6 @@ __all__ = ['UNCOMPUTED', ...@@ -24,7 +23,6 @@ __all__ = ['UNCOMPUTED',
'make_static'] 'make_static']
UNCOMPUTED = Keyword("UNCOMPUTED", False)
UNDEFINED = Keyword("UNDEFINED", False) UNDEFINED = Keyword("UNDEFINED", False)
def make_static(cls, fname): def make_static(cls, fname):
...@@ -58,11 +56,6 @@ def compute(*nodes): ...@@ -58,11 +56,6 @@ def compute(*nodes):
"""Recursively evaluate each node (in a quick & dirty way).""" """Recursively evaluate each node (in a quick & dirty way)."""
compute_from(nodes, set()) compute_from(nodes, set())
def is_result(obj):
"""Return True iff obj provides the interface of a Result"""
attr_list = 'data', 'owner'
return all([hasattr(obj, attr) for attr in attr_list])
class ForbidConstantOverwrite(features.Listener, features.Constraint): class ForbidConstantOverwrite(features.Listener, features.Constraint):
def __init__(self, env): def __init__(self, env):
...@@ -105,80 +98,81 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint): ...@@ -105,80 +98,81 @@ class ForbidConstantOverwrite(features.Listener, features.Constraint):
class ResultValue(Result): if 0:
"""Augment Result to wrap a computed value. class ResultValue(Result):
"""Augment Result to wrap a computed value.
Attributes: Attributes:
data - data -
spec - spec -
constant - constant -
up_to_date - up_to_date -
Properties: Properties:
Methods: Methods:
set_value_filter - ABSTRACT set_value_filter - ABSTRACT
set_value_inplace - ABSTRACT set_value_inplace - ABSTRACT
alloc - ABSTRACT alloc - ABSTRACT
Notes: Notes:
""" """
__slots__ = ['data', 'spec', 'constant', 'up_to_date'] __slots__ = ['data', 'spec', 'constant', 'up_to_date']
def __init__(self, x = UNCOMPUTED, constant = False):
self.constant = False
self.set_value(x) # allow set_value before constant = True
self.constant = constant
self.up_to_date = True
self.refresh() # to set spec
def __str__(self): return str(self.data) def __init__(self, x = None, constant = False):
self.constant = False
self.set_value(x) # allow set_value before constant = True
self.constant = constant
self.up_to_date = True
self.refresh() # to set spec
def __str__(self): return str(self.data)
def __repr__(self): return repr(self.data) def __repr__(self): return repr(self.data)
#TODO: document this function, what does it do? #TODO: document this function, what does it do?
def refresh(self): self.spec = id(self.data) def refresh(self): self.spec = id(self.data)
#################################################### ####################################################
# #
# Functionality provided by this class # Functionality provided by this class
# #
def set_value(self, value): def set_value(self, value):
if self.constant: if self.constant:
raise Exception("This Result is a constant. Its value cannot be changed.") raise Exception("This Result is a constant. Its value cannot be changed.")
if value is None or value is UNCOMPUTED: if value is None or value is UNCOMPUTED:
self.data = UNCOMPUTED self.data = UNCOMPUTED
elif isinstance(value, ResultValue): elif is_result(value):
self.set_value(value.data) self.set_value(value.data)
else: else:
try: try:
self.data = self.set_value_filter(value) self.data = self.set_value_filter(value)
except AbstractFunctionError, e: except AbstractFunctionError, e:
self.data = value self.data = value
self.up_to_date = True self.up_to_date = True
self.refresh() self.refresh()
#################################################### ####################################################
# #
# Pure virtual functions for subclasses to implement # Pure virtual functions for subclasses to implement
# #
# Perform error checking or automatic conversion of value, and return the # Perform error checking or automatic conversion of value, and return the
# result (which will be stored as self.data) # result (which will be stored as self.data)
# Called by: set_value() # Called by: set_value()
def set_value_filter(self, value): raise AbstractFunctionError() def set_value_filter(self, value): raise AbstractFunctionError()
# For mutable data types, overwrite the current contents with value # For mutable data types, overwrite the current contents with value
# Also, call refresh and set up_to_date = True # Also, call refresh and set up_to_date = True
def set_value_inplace(self, value): raise AbstractFunctionError() def set_value_inplace(self, value): raise AbstractFunctionError()
# Instantiate data (according to spec) # Instantiate data (according to spec)
def alloc(self): raise AbstractFunctionError() def alloc(self): raise AbstractFunctionError()
class DestroyHandler(features.Listener, features.Constraint, features.Orderings): class DestroyHandler(features.Listener, features.Constraint, features.Orderings):
...@@ -467,7 +461,8 @@ class PythonOp(Op): ...@@ -467,7 +461,8 @@ class PythonOp(Op):
return all([ is_result(i) for i in self.inputs]) return all([ is_result(i) for i in self.inputs])
def gen_outputs(self): def gen_outputs(self):
return [ResultValue() for i in xrange(self.nout)] raise NotImplementedError()
#return [ResultValue() for i in xrange(self.nout)]
def view_map(self): return {} def view_map(self): return {}
...@@ -500,7 +495,7 @@ class PythonOp(Op): ...@@ -500,7 +495,7 @@ class PythonOp(Op):
return answer return answer
def check_input(self, input): def check_input(self, input):
if input.data is UNCOMPUTED: if input.data is None:
raise ValueError("Uncomputed input: %s in %s" % (input, self)) raise ValueError("Uncomputed input: %s in %s" % (input, self))
if not self.input_is_up_to_date(input): if not self.input_is_up_to_date(input):
raise ValueError("Input is out of date: %s in %s" % (input, self)) raise ValueError("Input is out of date: %s in %s" % (input, self))
......
...@@ -10,7 +10,7 @@ from err import GofError ...@@ -10,7 +10,7 @@ from err import GofError
from utils import AbstractFunctionError from utils import AbstractFunctionError
__all__ = ['Result', 'BrokenLink', 'BrokenLinkError'] __all__ = ['is_result', 'Result', 'BrokenLink', 'BrokenLinkError']
class BrokenLink: class BrokenLink:
...@@ -40,6 +40,11 @@ class BrokenLinkError(GofError): ...@@ -40,6 +40,11 @@ class BrokenLinkError(GofError):
# Result # Result
############################ ############################
def is_result(obj):
"""Return True iff obj provides the interface of a Result"""
attr_list = 'owner',
return all([hasattr(obj, attr) for attr in attr_list])
class Result(object): class Result(object):
"""Storage node for data in a graph of Op instances. """Storage node for data in a graph of Op instances.
......
...@@ -57,32 +57,30 @@ class Grad(object): ...@@ -57,32 +57,30 @@ class Grad(object):
r may be uncomputed or NumpyR r may be uncomputed or NumpyR
""" """
if dr is core.UNDEFINED: if dr is core.UNDEFINED:
# nothing to do # nothing to do
pass return
if r.data is not None and dr.data is not None:
if not hasattr(r, 'shape'):
raise ValueError(('Grad::add r lacks shape: type=',
type(r)))
if not hasattr(dr, 'shape'):
raise ValueError(('Grad::add dr lacks shape: type=',
type(dr)))
if r.shape != dr.shape:
raise ValueError(('Grad::add r, dr shape mismatch',
r.shape, dr.shape))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if r.computed:
self._compute_history.add(r)
# add dr to self[r]
if r in self:
self[r] = self[r] + dr
else: else:
if r.data is core.UNCOMPUTED or dr.data is core.UNCOMPUTED: self[r] = dr
pass # no sanity checking
else: # some sanity checking to catch obvious mistakes
if not hasattr(r.data, 'shape'):
raise ValueError(('Grad::add r lacks shape: type=',
type(r.data)))
if not hasattr(dr.data, 'shape'):
raise ValueError(('Grad::add dr lacks shape: type=',
type(dr.data)))
if r.data.shape != dr.data.shape:
raise ValueError(('Grad::add r, dr shape mismatch',
r.data.shape, dr.data.shape))
# prevent 'r' from being re-calculated by self.__call__ in 'build_eval' mode
if r.data is not core.UNCOMPUTED:
self._compute_history.add(r)
# add dr to self[r]
if r in self:
self[r] = self[r] + dr
else:
self[r] = dr
def bprop(self, maybe_redo=False): def bprop(self, maybe_redo=False):
"""Build a backpropagation graph. """Build a backpropagation graph.
...@@ -213,14 +211,14 @@ class _testCase (unittest.TestCase): ...@@ -213,14 +211,14 @@ class _testCase (unittest.TestCase):
diff = wwi - ident diff = wwi - ident
ssdiff = core.sum((diff**2)) ssdiff = core.sum((diff**2))
if i == 0: if i == 0:
str0 = str_ssdiff = str(ssdiff) str0 = str_ssdiff = str(ssdiff.data)
#print ssdiff #print ssdiff
g = grad(ssdiff) g = grad(ssdiff)
gw = g(w) gw = g(w)
w.data += -0.4 * gw.data w.data[:] += -0.4 * gw.data
return str0, str(ssdiff) return str0, str(ssdiff.data)
def matinv_compiled(self, dim): def matinv_compiled(self, dim):
w = core.wrap(numpy.random.rand(dim,dim)) w = core.wrap(numpy.random.rand(dim,dim))
...@@ -230,7 +228,7 @@ class _testCase (unittest.TestCase): ...@@ -230,7 +228,7 @@ class _testCase (unittest.TestCase):
wwi = core.dot(w, wi) wwi = core.dot(w, wi)
diff = wwi - ident diff = wwi - ident
ssdiff = core.sum((diff**2)) ssdiff = core.sum((diff**2))
str0 = str_ssdiff = str(ssdiff) str0 = str_ssdiff = str(ssdiff.data)
#print ssdiff #print ssdiff
g = grad(ssdiff) g = grad(ssdiff)
...@@ -240,9 +238,9 @@ class _testCase (unittest.TestCase): ...@@ -240,9 +238,9 @@ class _testCase (unittest.TestCase):
for i in xrange(300): for i in xrange(300):
prog() prog()
w.data += -0.4 * gw.data w.data[:] += -0.4 * gw.data
return str0, str(ssdiff) return str0, str(ssdiff.data)
def test0(self): def test0(self):
"""Matrix inversion by gradient descent (eval mode)""" """Matrix inversion by gradient descent (eval mode)"""
......
...@@ -9,7 +9,7 @@ import grad ...@@ -9,7 +9,7 @@ import grad
# Wrapper type # Wrapper type
class SparseR(gof.ResultValue): class SparseR(core.ResultBase):
""" """
Attribute: Attribute:
format - a subclass of sparse.spmatrix indicating self.data.__class__ format - a subclass of sparse.spmatrix indicating self.data.__class__
...@@ -22,12 +22,17 @@ class SparseR(gof.ResultValue): ...@@ -22,12 +22,17 @@ class SparseR(gof.ResultValue):
Notes: Notes:
""" """
def __init__(self, x = core.UNCOMPUTED, constant = False, def __init__(self, data=None, role=None, constant = False,
format = sparse.csr_matrix): format = sparse.csr_matrix):
gof.ResultValue.__init__(self, x, constant) core.ResultBase.__init__(self, role, data, constant)
self.format = isinstance(x, sparse.spmatrix) and x.__class__ or format if isinstance(data, sparse.spmatrix):
self.format = data.__class__
else:
self.format = format
self._dtype = None
self._shape = None
def set_value_filter(self, value): def data_filter(self, value):
if isinstance(value, sparse.spmatrix): return value if isinstance(value, sparse.spmatrix): return value
return sparse.csr_matrix(value) return sparse.csr_matrix(value)
...@@ -36,6 +41,32 @@ class SparseR(gof.ResultValue): ...@@ -36,6 +41,32 @@ class SparseR(gof.ResultValue):
T = property(lambda self: transpose(self), doc = "Return aliased transpose") T = property(lambda self: transpose(self), doc = "Return aliased transpose")
# self._dtype is used when self._data hasn't been set yet
def __dtype_get(self):
if self._data is None:
return self._dtype
else:
return self._data.dtype
def __dtype_set(self, dtype):
if self._data is None:
self._dtype = dtype
else:
raise StateError('cannot set dtype after data has been set')
dtype = property(__dtype_get, __dtype_set)
# self._shape is used when self._data hasn't been set yet
def __shape_get(self):
if self._data is None:
return self._shape
else:
return self._data.shape
def __shape_set(self, shape):
if self._data is None:
self._shape = shape
else:
raise StateError('cannot set shape after data has been set')
shape = property(__shape_get, __shape_set)
# convenience base class # convenience base class
class op(gof.PythonOp, grad.update_gradient_via_grad): class op(gof.PythonOp, grad.update_gradient_via_grad):
pass pass
...@@ -46,7 +77,7 @@ class op(gof.PythonOp, grad.update_gradient_via_grad): ...@@ -46,7 +77,7 @@ class op(gof.PythonOp, grad.update_gradient_via_grad):
# convert a sparse matrix to an ndarray # convert a sparse matrix to an ndarray
class sparse2dense(op): class sparse2dense(op):
def gen_outputs(self): return [core.NumpyR()] def gen_outputs(self): return [core.Numpy2()]
def impl(x): return numpy.asarray(x.todense()) def impl(x): return numpy.asarray(x.todense())
def grad(self, x, gz): def grad(self, x, gz):
if x.format is sparse.coo_matrix: return dense2coo(gz) if x.format is sparse.coo_matrix: return dense2coo(gz)
...@@ -135,7 +166,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -135,7 +166,7 @@ class _testCase_dot(unittest.TestCase):
def test_basic0(self): def test_basic0(self):
for mtype in [sparse.csc_matrix, sparse.csr_matrix]: for mtype in [sparse.csc_matrix, sparse.csr_matrix]:
x = SparseR(mtype(sparse.speye(5,3))) x = SparseR(mtype(sparse.speye(5,3)))
y = core.NumpyR(numpy.random.rand(3, 2)) y = core.wrap(numpy.random.rand(3, 2))
z = dot(x,y) z = dot(x,y)
self.failUnless(z.data.shape == (5,2)) self.failUnless(z.data.shape == (5,2))
...@@ -171,7 +202,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -171,7 +202,7 @@ class _testCase_dot(unittest.TestCase):
self.failUnless(z.data.shape == ab.shape) self.failUnless(z.data.shape == ab.shape)
self.failUnless(type(z.data) == type(ab)) self.failUnless(type(z.data) == type(ab))
def test_graph_bprop0(self): def test_graph_bprop0(self):
x = core.NumpyR(numpy.random.rand(10,2)) x = core.wrap(numpy.random.rand(10,2))
w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0, w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0,
0]],dtype='float64'))) 0]],dtype='float64')))
...@@ -187,10 +218,10 @@ class _testCase_dot(unittest.TestCase): ...@@ -187,10 +218,10 @@ class _testCase_dot(unittest.TestCase):
g(w).data[1,4] = 0 g(w).data[1,4] = 0
w.data = -lr * g(w).data + w.data w.data = -lr * g(w).data + w.data
self.failUnless('3.08560636025' == str(loss)) self.failUnless('3.08560636025' == str(loss.data))
def test_graph_bprop1(self): def test_graph_bprop1(self):
x = core.NumpyR(numpy.random.rand(10,2)) x = core.wrap(numpy.random.rand(10,2))
w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0, w = SparseR(sparse.csr_matrix(numpy.asarray([[1, 0, 3, 0, 5], [0, 0, -2, 0,
0]],dtype='float64'))) 0]],dtype='float64')))
...@@ -205,7 +236,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -205,7 +236,7 @@ class _testCase_dot(unittest.TestCase):
g(w).data[1,4] = 0 g(w).data[1,4] = 0
w.data = -lr * g(w).data + w.data w.data = -lr * g(w).data + w.data
self.failUnless('3.08560636025' == str(loss)) self.failUnless('3.08560636025' == str(loss.data))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论