fixed core.literal(), all tests pass including getslice!

上级 cd8be514
...@@ -48,14 +48,6 @@ def _approx_eq(a,b,eps=1.0e-9): ...@@ -48,14 +48,6 @@ def _approx_eq(a,b,eps=1.0e-9):
return numpy.all(d < eps) return numpy.all(d < eps)
literals_db = {}
#literals_id_db = weakref.WeakValueDictionary()
literals_id_db = {}
#input floating point scalars will be cast to arrays of this type
# see TRAC(#31)
default_input_scalar_dtype = 'float64'
@blas._constant # TODO: move this decorator to a utility script @blas._constant # TODO: move this decorator to a utility script
def _compile_dir(): def _compile_dir():
"""Return the directory in which scipy.weave should store code objects. """Return the directory in which scipy.weave should store code objects.
...@@ -99,8 +91,12 @@ def input(x): ...@@ -99,8 +91,12 @@ def input(x):
# being cast to floating-point (can that cause incorrectness?) # being cast to floating-point (can that cause incorrectness?)
if isinstance(x, numpy.ndarray): if isinstance(x, numpy.ndarray):
return NumpyR(x) return NumpyR(x)
elif isinstance(x, (int, float)): elif isinstance(x, int):
z = numpy.zeros((), dtype = default_input_scalar_dtype) z = numpy.zeros((), dtype = input.int_dtype)
z += x
return NumpyR(z)
elif isinstance(x, float):
z = numpy.zeros((), dtype = input.float_dtype)
z += x z += x
return NumpyR(z) return NumpyR(z)
elif isinstance(x, gof.Result): elif isinstance(x, gof.Result):
...@@ -108,6 +104,9 @@ def input(x): ...@@ -108,6 +104,9 @@ def input(x):
else: else:
return ResultValue(x) return ResultValue(x)
input.float_dtype = 'float64'
input.int_dtype = 'int64'
def wrap(x): def wrap(x):
if isinstance(x, NumpyR): if isinstance(x, NumpyR):
return x return x
...@@ -118,38 +117,69 @@ def wrap(x): ...@@ -118,38 +117,69 @@ def wrap(x):
else: else:
return literal(x) return literal(x)
def _hashable(x): class _testCase_wrap(unittest.TestCase):
def setUp(self):
literal.hdb = {}
literal.udb = {}
def test_input_int(self):
w = input(3)
self.failUnless(isinstance(w, NumpyR))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
def test_input_float(self):
w = input(3.0)
self.failUnless(isinstance(w, NumpyR))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
def test_literal_int(self):
w = literal(3)
self.failUnless(isinstance(w, NumpyR))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
def test_literal_float(self):
w = literal(3.0)
self.failUnless(isinstance(w, NumpyR))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
def test_wrap_int(self):
w = wrap(3)
self.failUnless(isinstance(w, NumpyR))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
def test_wrap_float(self):
w = wrap(3.0)
self.failUnless(isinstance(w, NumpyR))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
def literal(x):
"""Return a ResultValue instance wrapping a literal."""
def _hashable(x):
try: try:
x in {} x in {}
return True return True
except TypeError: # x is unhashable except TypeError: # x is unhashable
return False return False
def _literal_hashable(x): #static member initialization
if x in literals_db: if not hasattr(literal, 'hdb'):
return literals_db[x] literal.hdb = {}
else: literal.udb = {}
r = input(x)
r.constant = True
literals_db[x] = r
return r
def _literal_unhashable(x): if _hashable(x):
idx = id(x) db = literal.hdb
if idx in literals_id_db: key = (id(x),x)
return literals_id_db[idx]
else: else:
r = input(x) db = literal.udb
r.constant = True key = (id(x),)
literals_id_db[idx] = r
return r
def literal(x): if key in db:
"""Return a ResultValue instance wrapping a literal.""" return db[key]
if _hashable(x):
return _literal_hashable(x)
else: else:
return _literal_unhashable(x) rval = input(x)
rval.constant = True
db[key] = rval
return rval
...@@ -682,9 +712,9 @@ class NumpyR(gof.ResultValue): ...@@ -682,9 +712,9 @@ class NumpyR(gof.ResultValue):
raise ValueError() raise ValueError()
else: else:
if 0 == len(self.data.shape): if 0 == len(self.data.shape):
self.data.itemset(value) self.data.itemset(value) # for scalars
else: else:
self.data[:] = value self.data[:] = value # for matrices
self.refresh() self.refresh()
self.up_to_date = True self.up_to_date = True
...@@ -693,7 +723,8 @@ class NumpyR(gof.ResultValue): ...@@ -693,7 +723,8 @@ class NumpyR(gof.ResultValue):
self.spec = (numpy.ndarray, self.data.dtype, self.data.shape) self.spec = (numpy.ndarray, self.data.dtype, self.data.shape)
def alloc(self): def alloc(self):
self.data = numpy.ndarray(self.spec[2], self.spec[1]) shape, dtype = self.spec[2], self.spec[1]
self.data = numpy.ndarray(shape, dtype=dtype)
def __add__(self, y): return add(self, y) def __add__(self, y): return add(self, y)
def __radd__(self, x): return add(x, self) def __radd__(self, x): return add(x, self)
...@@ -1487,10 +1518,9 @@ class _testCase_slicing(unittest.TestCase): ...@@ -1487,10 +1518,9 @@ class _testCase_slicing(unittest.TestCase):
self.fail('add should not have succeeded') self.fail('add should not have succeeded')
def test_getitem1(self): def test_getitem1(self):
#TODO: re-enable this test
return
a = numpy.ones((4,4)) a = numpy.ones((4,4))
wa1 = wrap(a)[1] wa1 = wrap(a)[1]
self.failUnless(wa1.data.shape == (4,))
def test_getslice_0d_all(self): def test_getslice_0d_all(self):
"""Test getslice does not work on 0d array """ """Test getslice does not work on 0d array """
......
...@@ -126,10 +126,10 @@ class ResultValue(Result): ...@@ -126,10 +126,10 @@ class ResultValue(Result):
def __init__(self, x = UNCOMPUTED, constant = False): def __init__(self, x = UNCOMPUTED, constant = False):
self.constant = False self.constant = False
self.set_value(x) self.set_value(x) # allow set_value before constant = True
self.constant = constant self.constant = constant
self.up_to_date = True self.up_to_date = True
self.spec = None self.refresh() # to set spec
def __str__(self): return str(self.data) def __str__(self): return str(self.data)
......
...@@ -118,6 +118,9 @@ class Op(object): ...@@ -118,6 +118,9 @@ class Op(object):
try: try:
self.validate() self.validate()
except: except:
# this call gives a subclass the chance to undo the set_outputs
# that it may have triggered...
# TODO: test this functionality!
self.set_input(i, previous, True, False) self.set_input(i, previous, True, False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论