提交 4a6845b1 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added test files and moved stuff out of core.py

上级 81c7aa5c
差异被折叠。
import unittest
from wrappers import *
class _testCase_input(unittest.TestCase):
def setUp(self):
literal.hdb = {}
literal.udb = {}
def test_input_int(self):
w = input(3)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
def test_input_float(self):
w = input(3.0)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
class _testCase_wrap(unittest.TestCase):
def setUp(self):
literal.hdb = {}
literal.udb = {}
def test_wrap_int(self):
w = wrap(3)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
def test_wrap_float(self):
w = wrap(3.0)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
class _testCase_literal(unittest.TestCase):
def setUp(self):
literal.hdb = {}
literal.udb = {}
def test_int(self):
w = literal(3)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.int_dtype)
self.failUnless(w.data == 3)
u = literal(1+2)
self.failUnless(u is w)
def test_float(self):
w = literal(3.0)
self.failUnless(isinstance(w, input.NN))
self.failUnless(str(w.data.dtype) == input.float_dtype)
self.failUnless(w.data == 3.0)
u = literal(1.0+2.0)
self.failUnless(u is w)
def test_mixed(self):
f = literal(2.0)
i = literal(2)
self.failUnless(i is not f)
if __name__ == '__main__':
unittest.main()
差异被折叠。
差异被折叠。
差异被折叠。
from core import Numpy2, omega_op
def input(x):
#static member initialization
if not hasattr(input, 'float_dtype'):
input.float_dtype = 'float64'
input.int_dtype = 'int64'
input.NN = Numpy2
if isinstance(x, numpy.ndarray):
#return NumpyR(x)
return input.NN(data=x)
elif isinstance(x, int):
z = numpy.zeros((), dtype = input.int_dtype)
z += x
return input.NN(data=z)
elif isinstance(x, float):
z = numpy.zeros((), dtype = input.float_dtype)
z += x
return input.NN(data=z)
elif is_result(x):
raise TypeError("%s is already a result." % x)
else:
return ResultBase(data=x)
def wrap(x):
if isinstance(x, Numpy2):
return x
#elif isinstance(x, NumpyR):
#return x
elif is_result(x):
return x
elif isinstance(x, omega_op):
return x.out
else:
return literal(x)
def literal(x):
"""Return a ResultValue instance wrapping a literal."""
def _hashable(x):
try:
x in {}
return True
except TypeError: # x is unhashable
return False
#static member initialization
if not hasattr(literal, 'hdb'):
literal.hdb = {}
literal.udb = {}
if _hashable(x):
db = literal.hdb
key = (type(x),x)
else:
db = literal.udb
key = (id(x),)
if key in db:
return db[key]
else:
rval = input(x)
rval.constant = True
db[key] = rval
return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论