fixed and passing again

上级 f94c222c
...@@ -12,7 +12,7 @@ from scipy import weave ...@@ -12,7 +12,7 @@ from scipy import weave
import gof import gof
from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode
from gof import pop_mode, UNDEFINED, is_result from gof import pop_mode, is_result
import type_spec import type_spec
import cutils import cutils
...@@ -251,9 +251,6 @@ class Numpy2(ResultBase): ...@@ -251,9 +251,6 @@ class Numpy2(ResultBase):
__array__ = property(lambda self: self._data.__array__ ) __array__ = property(lambda self: self._data.__array__ )
__array_struct__ = property(lambda self: self._data.__array_struct__ ) __array_struct__ = property(lambda self: self._data.__array_struct__ )
def data_set_inplace(self, data):
raise NotImplementedError()
def data_alloc(self): def data_alloc(self):
self.data = numpy.ndarray(self.shape, self.dtype) self.data = numpy.ndarray(self.shape, self.dtype)
...@@ -386,7 +383,7 @@ def input(x): ...@@ -386,7 +383,7 @@ def input(x):
elif is_result(x): elif is_result(x):
raise TypeError("%s is already a result." % x) raise TypeError("%s is already a result." % x)
else: else:
return ResultBase(x) return ResultBase(data=x)
class _testCase_input(unittest.TestCase): class _testCase_input(unittest.TestCase):
def setUp(self): def setUp(self):
literal.hdb = {} literal.hdb = {}
...@@ -1067,11 +1064,14 @@ if 0: ...@@ -1067,11 +1064,14 @@ if 0:
from grad import Grad
def wrap_producer(f): def wrap_producer(f):
class producer(omega_op): class producer(omega_op):
impl = f impl = f
def grad(*args): def grad(*args):
return [UNDEFINED] * (len(args) - 1) return [Grad.Undefined] * (len(args) - 1)
producer.__name__ = f.__name__ producer.__name__ = f.__name__
def ret(dim, dtype = 'float', order = 'C'): def ret(dim, dtype = 'float', order = 'C'):
return producer(dim, dtype, order) return producer(dim, dtype, order)
...@@ -1811,11 +1811,11 @@ class sum(elemwise): ...@@ -1811,11 +1811,11 @@ class sum(elemwise):
class ones_like(elemwise): class ones_like(elemwise):
impl = numpy.ones_like impl = numpy.ones_like
def grad(x, gz): return UNDEFINED def grad(x, gz): return Grad.Undefined
class zeros_like(elemwise): class zeros_like(elemwise):
impl = numpy.zeros_like impl = numpy.zeros_like
def grad(x, gz): return UNDEFINED def grad(x, gz): return Grad.Undefined
## Array slicing ## ## Array slicing ##
......
...@@ -17,6 +17,8 @@ class Grad(object): ...@@ -17,6 +17,8 @@ class Grad(object):
__call__() __call__()
__getitem__() __getitem__()
""" """
class Undefined: pass
def __init__(self, dct={}): def __init__(self, dct={}):
self.map = {} self.map = {}
self.outputs = [] self.outputs = []
...@@ -34,7 +36,7 @@ class Grad(object): ...@@ -34,7 +36,7 @@ class Grad(object):
try: try:
return self.map[key] return self.map[key]
except KeyError: except KeyError:
return core.UNDEFINED return Grad.Undefined
def __setitem__(self, item, val): def __setitem__(self, item, val):
"""Map item to its id and store internally.""" """Map item to its id and store internally."""
...@@ -49,7 +51,7 @@ class Grad(object): ...@@ -49,7 +51,7 @@ class Grad(object):
This function should be fed as follows: This function should be fed as follows:
if dr is UNDEFINED: if dr is undefined:
r could be anything r could be anything
else dr might be core.UNCOMPUTED: else dr might be core.UNCOMPUTED:
r may be uncomputed or NumpyR r may be uncomputed or NumpyR
...@@ -57,7 +59,7 @@ class Grad(object): ...@@ -57,7 +59,7 @@ class Grad(object):
r may be uncomputed or NumpyR r may be uncomputed or NumpyR
""" """
if dr is core.UNDEFINED: if dr is Grad.Undefined:
# nothing to do # nothing to do
return return
...@@ -122,7 +124,7 @@ class Grad(object): ...@@ -122,7 +124,7 @@ class Grad(object):
if not self.did_bprop: if not self.did_bprop:
raise Exception('Grad.__call__ only makes sense after a bprop') raise Exception('Grad.__call__ only makes sense after a bprop')
rval = self[item] rval = self[item]
if rval is not core.UNDEFINED \ if rval is not Grad.Undefined \
and core.current_mode() == 'build_eval': and core.current_mode() == 'build_eval':
compute_from([rval], self._compute_history) compute_from([rval], self._compute_history)
return rval return rval
...@@ -297,7 +299,7 @@ class _testCase (unittest.TestCase): ...@@ -297,7 +299,7 @@ class _testCase (unittest.TestCase):
gb.bprop() gb.bprop()
self.assertEqual('should have raised',0) self.assertEqual('should have raised',0)
except AttributeError, e: except AttributeError, e:
self.assertEqual(str(e), "Keyword instance has no attribute 'shape'") self.assertEqual(str(e), "class Undefined has no attribute 'shape'")
return return
self.assertEqual("Should have been error", 0) self.assertEqual("Should have been error", 0)
...@@ -311,7 +313,7 @@ class _testCase (unittest.TestCase): ...@@ -311,7 +313,7 @@ class _testCase (unittest.TestCase):
gc.bprop() gc.bprop()
self.assertEqual('should have raised',0) self.assertEqual('should have raised',0)
except AttributeError, e: except AttributeError, e:
self.assertEqual(str(e), "Keyword instance has no attribute 'shape'") self.assertEqual(str(e), "class Undefined has no attribute 'shape'")
return return
self.assertEqual("Should have been error", 0) self.assertEqual("Should have been error", 0)
......
...@@ -182,7 +182,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -182,7 +182,7 @@ class _testCase_dot(unittest.TestCase):
m = mtype(a) m = mtype(a)
ab = m.dot(b) ab = m.dot(b)
try: try:
z = dot(SparseR(m),gof.lib.ResultValue(b)) z = dot(SparseR(m),core.ResultBase(data=b))
self.failUnless(z.data.shape == ab.shape) self.failUnless(z.data.shape == ab.shape)
self.failUnless(type(z.data) == type(ab)) self.failUnless(type(z.data) == type(ab))
except Exception, e: except Exception, e:
...@@ -198,7 +198,7 @@ class _testCase_dot(unittest.TestCase): ...@@ -198,7 +198,7 @@ class _testCase_dot(unittest.TestCase):
sparse.lil_matrix]:#, sparse.coo_matrix]: sparse.lil_matrix]:#, sparse.coo_matrix]:
m = mtype(b) m = mtype(b)
ab = m.transpose().dot(a.transpose()).transpose() ab = m.transpose().dot(a.transpose()).transpose()
z = dot(gof.lib.ResultValue(a),SparseR(mtype(b))) z = dot(core.ResultBase(data=a),SparseR(mtype(b)))
self.failUnless(z.data.shape == ab.shape) self.failUnless(z.data.shape == ab.shape)
self.failUnless(type(z.data) == type(ab)) self.failUnless(type(z.data) == type(ab))
def test_graph_bprop0(self): def test_graph_bprop0(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论