提交 4d32f45e authored 作者: James Bergstra's avatar James Bergstra

comparisons and logical ops almost working...

上级 1c973723
......@@ -84,27 +84,37 @@ class _test_logical(unittest.TestCase):
def test_or(self):
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [or_(x, y)])).make_function()
fn = gof.DualLinker(Env([x,y], [x|y])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a or b), (a,b))
self.failUnless(fn(a,b) == (a|b), (a,b))
def test_xor(self):
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [xor(x, y)])).make_function()
fn = gof.DualLinker(Env([x,y], [x^y])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (operator.xor(a, b)), (a,b))
self.failUnless(fn(a,b) == (a ^ b), (a,b))
def test_and(self):
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [and_(x, y)])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a and b), (a,b))
self.failUnless(fn(a,b) == (a & b), (a,b))
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [x & y])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a & b), (a,b))
def test_not(self):
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [not_(x)])).make_function()
fn = gof.DualLinker(Env([x,y], [invert(x)])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == ~a, (a,))
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [~x])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (not a), (a,))
self.failUnless(fn(a,b) == ~a, (a,))
if __name__ == '__main__':
unittest.main()
......
import traceback
import operator
from tensor import *
import tensor # for hidden symbols
......@@ -883,6 +885,81 @@ class T_Stack(unittest.TestCase):
self.failUnless((eval_outputs([s]) == c).all())
class _test_comparison(unittest.TestCase):
def test_gt(self):
x, y = fvector(), fvector()
fn = function([x,y], [x > y])
l = numpy.asarray([0.,-1.,1.])
r = numpy.asarray([0.,1.,-1.])
v = fn(l, r)
self.failUnless(numpy.all(v == (l > r)), (v, (l>r)))
def test_lt(self):
x, y = fvector(), fvector()
fn = function([x,y], [x < y])
l = numpy.asarray([0.,-1.,1.])
r = numpy.asarray([0.,1.,-1.])
v = fn(l, r)
self.failUnless(numpy.all(v == (l < r)), (v, (l<r)))
def test_le(self):
x, y = fvector(), fvector()
fn = function([x,y], [x <= y])
l = numpy.asarray([0.,-1.,1.])
r = numpy.asarray([0.,1.,-1.])
v = fn(l, r)
self.failUnless(numpy.all(v == (l <= r)), (v, (l<=r)))
def test_ge(self):
x, y = fvector(), fvector()
fn = function([x,y], [x >= y])
l = numpy.asarray([0.,-1.,1.])
r = numpy.asarray([0.,1.,-1.])
v = fn(l, r)
self.failUnless(numpy.all(v == (l >= r)), (v, (l>=r)))
class _test_bitwise(unittest.TestCase):
def test_or(self):
x, y = bvector(), bvector()
fn = function([x,y], [x|y])
l = numpy.asarray([0,0,1,1])
r = numpy.asarray([0,1,0,1])
v = fn(l, r)
self.failUnless(numpy.all(v == (operator.or_(l, r))), (l, r, v))
def test_xor(self):
x, y = bvector(), bvector()
fn = function([x,y], [x^y])
ix = x
ix ^= y
gn = function([x,y], [ix])
l = numpy.asarray([0,0,1,1])
r = numpy.asarray([0,1,0,1])
v = fn(l, r)
self.failUnless(numpy.all(v == (operator.xor(l, r))), (l, r, v))
print ' '
print l, type(l)
v = gn(l, r)
#test the in-place stuff
print l, type(l)
print v, type(l)
self.failUnless(numpy.all(l == numpy.asarray([0,1,1,0])), l)
def test_and(self):
x, y = bvector(), bvector()
fn = function([x,y], [x&y])
l = numpy.asarray([0,0,1,1])
r = numpy.asarray([0,1,0,1])
v = fn(l, r)
self.failUnless(numpy.all(v == (operator.and_(l, r))), (l, r, v))
def test_inv(self):
x, y = bvector(), bvector()
fn = function([x,y], [~x])
l = numpy.asarray([0,0,1,1])
r = numpy.asarray([0,1,0,1])
v = fn(l, r)
self.failUnless(numpy.all(v == (~l)), (l, r, v))
......
......@@ -183,6 +183,15 @@ class _scalar_py_operators:
def __float__(self): return AsInt(self).out
def __complex__(self): return AsComplex(self).out
#BITWISE
def __invert__(self): return invert(self)
def __and__(self,other): return and_(self, other)
def __or__(self,other): return or_(self, other)
def __xor__(self,other): return xor(self, other)
def __rand__(self,other): return and_(other,self)
def __ror__(self,other): return or_(other, self)
def __rxor__(self,other): return xor(other, self)
#COMPARISONS
def __lt__(self,other): return lt(self, other)
def __le__(self,other): return le(self, other)
......@@ -326,21 +335,18 @@ class UnaryScalarOp(ScalarOp):
class BinaryScalarOp(ScalarOp):
nin = 2
class UnaryLogicalOp(UnaryScalarOp):
def output_types(self, *input_dtypes):
return [int8]
def grad(self, inputs, output_gradients):
return [None]
class BinaryLogicalOp(BinaryScalarOp):
###############
# Comparisons
###############
class LogicalComparison(BinaryScalarOp):
def output_types(self, *input_dtypes):
return [int8]
def grad(self, inputs, output_gradients):
return [None, None]
class LT(BinaryLogicalOp):
class LT(LogicalComparison):
identity = False
commutative = False
associative = False
......@@ -348,7 +354,7 @@ class LT(BinaryLogicalOp):
return x < y
lt = LT()
class GT(BinaryLogicalOp):
class GT(LogicalComparison):
identity = False
commutative = False
associative = False
......@@ -356,7 +362,7 @@ class GT(BinaryLogicalOp):
return x > y
gt = GT()
class LE(BinaryLogicalOp):
class LE(LogicalComparison):
identity = False
commutative = False
associative = False
......@@ -364,7 +370,7 @@ class LE(BinaryLogicalOp):
return x <= y
le = LE()
class GE(BinaryLogicalOp):
class GE(LogicalComparison):
identity = False
commutative = False
associative = False
......@@ -372,7 +378,7 @@ class GE(BinaryLogicalOp):
return x >= y
ge = GE()
class EQ(BinaryLogicalOp):
class EQ(LogicalComparison):
identity = False
commutative = True
associative = False
......@@ -380,38 +386,64 @@ class EQ(BinaryLogicalOp):
return x == y
eq = EQ()
class OR(BinaryLogicalOp):
####################
# BIT-WISE OPERATORS
####################
class UnaryBitOp(UnaryScalarOp):
def output_types(self, *input_types):
for i in input_types[0]:
if i not in (int8, int32, int64):
raise TypeError('input to a BitOp must have type int8, int32 or int 64... not %s' % i)
return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients):
return [None]
class BinaryBitOp(BinaryScalarOp):
def output_types(self, *input_types):
t0, t1 = input_types[0]
for i in input_types[0]:
if i not in (int8, int32, int64):
raise TypeError('input to a BitOp must have type int8, int32 or int 64... not %s' % i)
return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients):
return [None, None]
class OR(BinaryBitOp):
identity = False
commutative = True
associative = False
def impl(self, x, y):
return (x or y)
return x | y
or_ = OR()
class XOR(BinaryLogicalOp):
class XOR(BinaryBitOp):
identity = False
commutative = True
associative = False
def impl(self, x, y):
return operator.xor(x, y)
return x ^ y
xor = XOR()
class AND(BinaryLogicalOp):
class AND(BinaryBitOp):
identity = False
commutative = True
associative = False
def impl(self, x, y):
return x and y
return x & y
and_ = AND()
class NOT(UnaryLogicalOp):
class Invert(UnaryBitOp):
identity = False
def impl(self, x):
return not x
not_ = NOT()
return ~x
invert = Invert()
##############
# Arithmetic
##############
class Add(ScalarOp):
identity = 0
......
......@@ -252,6 +252,7 @@ def _multi(*fns):
fscalar = Tensor('float32', ())
dscalar = Tensor('float64', ())
bscalar = Tensor('int8', ())
iscalar = Tensor('int32', ())
lscalar = Tensor('int64', ())
def scalar(name = None, dtype = 'float64'):
......@@ -261,6 +262,7 @@ scalars, fscalars, dscalars, iscalars, lscalars = _multi(scalar, fscalar, dscala
fvector = Tensor('float32', (False, ))
dvector = Tensor('float64', (False, ))
bvector = Tensor('int8', (False,))
ivector = Tensor('int32', (False, ))
lvector = Tensor('int64', (False, ))
def vector(name = None, dtype = 'float64'):
......@@ -270,6 +272,7 @@ vectors, fvectors, dvectors, ivectors, lvectors = _multi(vector, fvector, dvecto
fmatrix = Tensor('float32', (False, False))
dmatrix = Tensor('float64', (False, False))
bmatrix = Tensor('int8', (False, False))
imatrix = Tensor('int32', (False, False))
lmatrix = Tensor('int64', (False, False))
def matrix(name = None, dtype = 'float64'):
......@@ -279,6 +282,7 @@ matrices, fmatrices, dmatrices, imatrices, lmatrices = _multi(matrix, fmatrix, d
frow = Tensor('float32', (True, False))
drow = Tensor('float64', (True, False))
brow = Tensor('int8', (True, False))
irow = Tensor('int32', (True, False))
lrow = Tensor('int64', (True, False))
def row(name = None, dtype = 'float64'):
......@@ -288,6 +292,7 @@ rows, frows, drows, irows, lrows = _multi(row, frow, drow, irow, lrow)
fcol = Tensor('float32', (False, True))
dcol = Tensor('float64', (False, True))
bcol = Tensor('int8', (False, True))
icol = Tensor('int32', (False, True))
lcol = Tensor('int64', (False, True))
def col(name = None, dtype = 'float64'):
......@@ -312,6 +317,18 @@ class _tensor_py_operators:
def __gt__(self,other): return gt(self, other)
def __ge__(self,other): return ge(self, other)
#BITWISE
def __invert__(self): return invert(self)
def __and__(self,other): return and_(self, other)
def __or__(self,other): return or_(self, other)
def __xor__(self,other): return xor(self, other)
def __rand__(self,other): return and_(other,self)
def __ror__(self,other): return or_(other, self)
def __rxor__(self,other): return xor(other, self)
def __iand__(self, other): return and_inplace(self, other)
def __ior__(self, other): return or_inplace(self, other)
def __ixor__(self, other): return xor_inplace(self, other)
#ARITHMETIC - NORMAL
def __add__(self,other): return add(self,other)
def __sub__(self,other): return sub(self,other)
......@@ -460,6 +477,29 @@ def _elemwise(scalar_op, name):
inplace = s2t.Elemwise(inplace_scalar_op, {0: 0})
return straight, inplace
##########################
# Comparison
##########################
lt, lt_inplace = _elemwise(scal.lt, 'lt')
gt, gt_inplace = _elemwise(scal.gt, 'gt')
le, le_inplace = _elemwise(scal.le, 'le')
ge, ge_inplace = _elemwise(scal.ge, 'ge')
##########################
# Bit-wise
##########################
and_, and_inplace = _elemwise(scal.and_, 'and_')
or_, or_inplace = _elemwise(scal.or_, 'or_')
xor, xor_inplace = _elemwise(scal.xor, 'xor')
invert, invert_inplace = _elemwise(scal.invert, 'invert')
##########################
# Math
##########################
_abs, abs_inplace = _elemwise(scal.abs, 'abs')
exp, exp_inplace = _elemwise(scal.exp, 'exp')
neg, neg_inplace = _elemwise(scal.neg, 'neg')
......@@ -475,6 +515,11 @@ cosh, cosh_inplace = _elemwise(scal.cosh, 'cosh')
sinh, sinh_inplace = _elemwise(scal.sinh, 'sinh')
tanh, tanh_inplace = _elemwise(scal.tanh, 'tanh')
##########################
# Misc
##########################
fill, fill_inplace = _elemwise(scal.second, 'fill')
def ones_like(model):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论