提交 4d32f45e authored 作者: James Bergstra's avatar James Bergstra

comparisons and logical ops almost working...

上级 1c973723
...@@ -84,27 +84,37 @@ class _test_logical(unittest.TestCase): ...@@ -84,27 +84,37 @@ class _test_logical(unittest.TestCase):
def test_or(self): def test_or(self):
x, y, z = ints('xyz') x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [or_(x, y)])).make_function() fn = gof.DualLinker(Env([x,y], [x|y])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)): for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a or b), (a,b)) self.failUnless(fn(a,b) == (a|b), (a,b))
def test_xor(self): def test_xor(self):
x, y, z = ints('xyz') x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [xor(x, y)])).make_function() fn = gof.DualLinker(Env([x,y], [x^y])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)): for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (operator.xor(a, b)), (a,b)) self.failUnless(fn(a,b) == (a ^ b), (a,b))
def test_and(self): def test_and(self):
x, y, z = ints('xyz') x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [and_(x, y)])).make_function() fn = gof.DualLinker(Env([x,y], [and_(x, y)])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)): for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a and b), (a,b)) self.failUnless(fn(a,b) == (a & b), (a,b))
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [x & y])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a & b), (a,b))
def test_not(self): def test_not(self):
x, y, z = ints('xyz') x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [not_(x)])).make_function() fn = gof.DualLinker(Env([x,y], [invert(x)])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == ~a, (a,))
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [~x])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)): for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (not a), (a,)) self.failUnless(fn(a,b) == ~a, (a,))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
import traceback import traceback
import operator
from tensor import * from tensor import *
import tensor # for hidden symbols import tensor # for hidden symbols
...@@ -883,6 +885,81 @@ class T_Stack(unittest.TestCase): ...@@ -883,6 +885,81 @@ class T_Stack(unittest.TestCase):
self.failUnless((eval_outputs([s]) == c).all()) self.failUnless((eval_outputs([s]) == c).all())
class _test_comparison(unittest.TestCase):
def test_gt(self):
x, y = fvector(), fvector()
fn = function([x,y], [x > y])
l = numpy.asarray([0.,-1.,1.])
r = numpy.asarray([0.,1.,-1.])
v = fn(l, r)
self.failUnless(numpy.all(v == (l > r)), (v, (l>r)))
def test_lt(self):
x, y = fvector(), fvector()
fn = function([x,y], [x < y])
l = numpy.asarray([0.,-1.,1.])
r = numpy.asarray([0.,1.,-1.])
v = fn(l, r)
self.failUnless(numpy.all(v == (l < r)), (v, (l<r)))
def test_le(self):
x, y = fvector(), fvector()
fn = function([x,y], [x <= y])
l = numpy.asarray([0.,-1.,1.])
r = numpy.asarray([0.,1.,-1.])
v = fn(l, r)
self.failUnless(numpy.all(v == (l <= r)), (v, (l<=r)))
def test_ge(self):
x, y = fvector(), fvector()
fn = function([x,y], [x >= y])
l = numpy.asarray([0.,-1.,1.])
r = numpy.asarray([0.,1.,-1.])
v = fn(l, r)
self.failUnless(numpy.all(v == (l >= r)), (v, (l>=r)))
class _test_bitwise(unittest.TestCase):
def test_or(self):
x, y = bvector(), bvector()
fn = function([x,y], [x|y])
l = numpy.asarray([0,0,1,1])
r = numpy.asarray([0,1,0,1])
v = fn(l, r)
self.failUnless(numpy.all(v == (operator.or_(l, r))), (l, r, v))
def test_xor(self):
x, y = bvector(), bvector()
fn = function([x,y], [x^y])
ix = x
ix ^= y
gn = function([x,y], [ix])
l = numpy.asarray([0,0,1,1])
r = numpy.asarray([0,1,0,1])
v = fn(l, r)
self.failUnless(numpy.all(v == (operator.xor(l, r))), (l, r, v))
print ' '
print l, type(l)
v = gn(l, r)
#test the in-place stuff
print l, type(l)
print v, type(l)
self.failUnless(numpy.all(l == numpy.asarray([0,1,1,0])), l)
def test_and(self):
x, y = bvector(), bvector()
fn = function([x,y], [x&y])
l = numpy.asarray([0,0,1,1])
r = numpy.asarray([0,1,0,1])
v = fn(l, r)
self.failUnless(numpy.all(v == (operator.and_(l, r))), (l, r, v))
def test_inv(self):
x, y = bvector(), bvector()
fn = function([x,y], [~x])
l = numpy.asarray([0,0,1,1])
r = numpy.asarray([0,1,0,1])
v = fn(l, r)
self.failUnless(numpy.all(v == (~l)), (l, r, v))
......
...@@ -183,6 +183,15 @@ class _scalar_py_operators: ...@@ -183,6 +183,15 @@ class _scalar_py_operators:
def __float__(self): return AsInt(self).out def __float__(self): return AsInt(self).out
def __complex__(self): return AsComplex(self).out def __complex__(self): return AsComplex(self).out
#BITWISE
def __invert__(self): return invert(self)
def __and__(self,other): return and_(self, other)
def __or__(self,other): return or_(self, other)
def __xor__(self,other): return xor(self, other)
def __rand__(self,other): return and_(other,self)
def __ror__(self,other): return or_(other, self)
def __rxor__(self,other): return xor(other, self)
#COMPARISONS #COMPARISONS
def __lt__(self,other): return lt(self, other) def __lt__(self,other): return lt(self, other)
def __le__(self,other): return le(self, other) def __le__(self,other): return le(self, other)
...@@ -326,21 +335,18 @@ class UnaryScalarOp(ScalarOp): ...@@ -326,21 +335,18 @@ class UnaryScalarOp(ScalarOp):
class BinaryScalarOp(ScalarOp): class BinaryScalarOp(ScalarOp):
nin = 2 nin = 2
class UnaryLogicalOp(UnaryScalarOp):
def output_types(self, *input_dtypes):
return [int8]
def grad(self, inputs, output_gradients):
return [None]
class BinaryLogicalOp(BinaryScalarOp): ###############
# Comparisons
###############
class LogicalComparison(BinaryScalarOp):
def output_types(self, *input_dtypes): def output_types(self, *input_dtypes):
return [int8] return [int8]
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
return [None, None] return [None, None]
class LT(LogicalComparison):
class LT(BinaryLogicalOp):
identity = False identity = False
commutative = False commutative = False
associative = False associative = False
...@@ -348,7 +354,7 @@ class LT(BinaryLogicalOp): ...@@ -348,7 +354,7 @@ class LT(BinaryLogicalOp):
return x < y return x < y
lt = LT() lt = LT()
class GT(BinaryLogicalOp): class GT(LogicalComparison):
identity = False identity = False
commutative = False commutative = False
associative = False associative = False
...@@ -356,7 +362,7 @@ class GT(BinaryLogicalOp): ...@@ -356,7 +362,7 @@ class GT(BinaryLogicalOp):
return x > y return x > y
gt = GT() gt = GT()
class LE(BinaryLogicalOp): class LE(LogicalComparison):
identity = False identity = False
commutative = False commutative = False
associative = False associative = False
...@@ -364,7 +370,7 @@ class LE(BinaryLogicalOp): ...@@ -364,7 +370,7 @@ class LE(BinaryLogicalOp):
return x <= y return x <= y
le = LE() le = LE()
class GE(BinaryLogicalOp): class GE(LogicalComparison):
identity = False identity = False
commutative = False commutative = False
associative = False associative = False
...@@ -372,7 +378,7 @@ class GE(BinaryLogicalOp): ...@@ -372,7 +378,7 @@ class GE(BinaryLogicalOp):
return x >= y return x >= y
ge = GE() ge = GE()
class EQ(BinaryLogicalOp): class EQ(LogicalComparison):
identity = False identity = False
commutative = True commutative = True
associative = False associative = False
...@@ -380,38 +386,64 @@ class EQ(BinaryLogicalOp): ...@@ -380,38 +386,64 @@ class EQ(BinaryLogicalOp):
return x == y return x == y
eq = EQ() eq = EQ()
class OR(BinaryLogicalOp): ####################
# BIT-WISE OPERATORS
####################
class UnaryBitOp(UnaryScalarOp):
def output_types(self, *input_types):
for i in input_types[0]:
if i not in (int8, int32, int64):
raise TypeError('input to a BitOp must have type int8, int32 or int 64... not %s' % i)
return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients):
return [None]
class BinaryBitOp(BinaryScalarOp):
def output_types(self, *input_types):
t0, t1 = input_types[0]
for i in input_types[0]:
if i not in (int8, int32, int64):
raise TypeError('input to a BitOp must have type int8, int32 or int 64... not %s' % i)
return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients):
return [None, None]
class OR(BinaryBitOp):
identity = False identity = False
commutative = True commutative = True
associative = False associative = False
def impl(self, x, y): def impl(self, x, y):
return (x or y) return x | y
or_ = OR() or_ = OR()
class XOR(BinaryLogicalOp): class XOR(BinaryBitOp):
identity = False identity = False
commutative = True commutative = True
associative = False associative = False
def impl(self, x, y): def impl(self, x, y):
return operator.xor(x, y) return x ^ y
xor = XOR() xor = XOR()
class AND(BinaryBitOp):
class AND(BinaryLogicalOp):
identity = False identity = False
commutative = True commutative = True
associative = False associative = False
def impl(self, x, y): def impl(self, x, y):
return x and y return x & y
and_ = AND() and_ = AND()
class NOT(UnaryLogicalOp): class Invert(UnaryBitOp):
identity = False identity = False
def impl(self, x): def impl(self, x):
return not x return ~x
not_ = NOT() invert = Invert()
##############
# Arithmetic
##############
class Add(ScalarOp): class Add(ScalarOp):
identity = 0 identity = 0
......
...@@ -252,6 +252,7 @@ def _multi(*fns): ...@@ -252,6 +252,7 @@ def _multi(*fns):
fscalar = Tensor('float32', ()) fscalar = Tensor('float32', ())
dscalar = Tensor('float64', ()) dscalar = Tensor('float64', ())
bscalar = Tensor('int8', ())
iscalar = Tensor('int32', ()) iscalar = Tensor('int32', ())
lscalar = Tensor('int64', ()) lscalar = Tensor('int64', ())
def scalar(name = None, dtype = 'float64'): def scalar(name = None, dtype = 'float64'):
...@@ -261,6 +262,7 @@ scalars, fscalars, dscalars, iscalars, lscalars = _multi(scalar, fscalar, dscala ...@@ -261,6 +262,7 @@ scalars, fscalars, dscalars, iscalars, lscalars = _multi(scalar, fscalar, dscala
fvector = Tensor('float32', (False, )) fvector = Tensor('float32', (False, ))
dvector = Tensor('float64', (False, )) dvector = Tensor('float64', (False, ))
bvector = Tensor('int8', (False,))
ivector = Tensor('int32', (False, )) ivector = Tensor('int32', (False, ))
lvector = Tensor('int64', (False, )) lvector = Tensor('int64', (False, ))
def vector(name = None, dtype = 'float64'): def vector(name = None, dtype = 'float64'):
...@@ -270,6 +272,7 @@ vectors, fvectors, dvectors, ivectors, lvectors = _multi(vector, fvector, dvecto ...@@ -270,6 +272,7 @@ vectors, fvectors, dvectors, ivectors, lvectors = _multi(vector, fvector, dvecto
fmatrix = Tensor('float32', (False, False)) fmatrix = Tensor('float32', (False, False))
dmatrix = Tensor('float64', (False, False)) dmatrix = Tensor('float64', (False, False))
bmatrix = Tensor('int8', (False, False))
imatrix = Tensor('int32', (False, False)) imatrix = Tensor('int32', (False, False))
lmatrix = Tensor('int64', (False, False)) lmatrix = Tensor('int64', (False, False))
def matrix(name = None, dtype = 'float64'): def matrix(name = None, dtype = 'float64'):
...@@ -279,6 +282,7 @@ matrices, fmatrices, dmatrices, imatrices, lmatrices = _multi(matrix, fmatrix, d ...@@ -279,6 +282,7 @@ matrices, fmatrices, dmatrices, imatrices, lmatrices = _multi(matrix, fmatrix, d
frow = Tensor('float32', (True, False)) frow = Tensor('float32', (True, False))
drow = Tensor('float64', (True, False)) drow = Tensor('float64', (True, False))
brow = Tensor('int8', (True, False))
irow = Tensor('int32', (True, False)) irow = Tensor('int32', (True, False))
lrow = Tensor('int64', (True, False)) lrow = Tensor('int64', (True, False))
def row(name = None, dtype = 'float64'): def row(name = None, dtype = 'float64'):
...@@ -288,6 +292,7 @@ rows, frows, drows, irows, lrows = _multi(row, frow, drow, irow, lrow) ...@@ -288,6 +292,7 @@ rows, frows, drows, irows, lrows = _multi(row, frow, drow, irow, lrow)
fcol = Tensor('float32', (False, True)) fcol = Tensor('float32', (False, True))
dcol = Tensor('float64', (False, True)) dcol = Tensor('float64', (False, True))
bcol = Tensor('int8', (False, True))
icol = Tensor('int32', (False, True)) icol = Tensor('int32', (False, True))
lcol = Tensor('int64', (False, True)) lcol = Tensor('int64', (False, True))
def col(name = None, dtype = 'float64'): def col(name = None, dtype = 'float64'):
...@@ -312,6 +317,18 @@ class _tensor_py_operators: ...@@ -312,6 +317,18 @@ class _tensor_py_operators:
def __gt__(self,other): return gt(self, other) def __gt__(self,other): return gt(self, other)
def __ge__(self,other): return ge(self, other) def __ge__(self,other): return ge(self, other)
#BITWISE
def __invert__(self): return invert(self)
def __and__(self,other): return and_(self, other)
def __or__(self,other): return or_(self, other)
def __xor__(self,other): return xor(self, other)
def __rand__(self,other): return and_(other,self)
def __ror__(self,other): return or_(other, self)
def __rxor__(self,other): return xor(other, self)
def __iand__(self, other): return and_inplace(self, other)
def __ior__(self, other): return or_inplace(self, other)
def __ixor__(self, other): return xor_inplace(self, other)
#ARITHMETIC - NORMAL #ARITHMETIC - NORMAL
def __add__(self,other): return add(self,other) def __add__(self,other): return add(self,other)
def __sub__(self,other): return sub(self,other) def __sub__(self,other): return sub(self,other)
...@@ -460,6 +477,29 @@ def _elemwise(scalar_op, name): ...@@ -460,6 +477,29 @@ def _elemwise(scalar_op, name):
inplace = s2t.Elemwise(inplace_scalar_op, {0: 0}) inplace = s2t.Elemwise(inplace_scalar_op, {0: 0})
return straight, inplace return straight, inplace
##########################
# Comparison
##########################
lt, lt_inplace = _elemwise(scal.lt, 'lt')
gt, gt_inplace = _elemwise(scal.gt, 'gt')
le, le_inplace = _elemwise(scal.le, 'le')
ge, ge_inplace = _elemwise(scal.ge, 'ge')
##########################
# Bit-wise
##########################
and_, and_inplace = _elemwise(scal.and_, 'and_')
or_, or_inplace = _elemwise(scal.or_, 'or_')
xor, xor_inplace = _elemwise(scal.xor, 'xor')
invert, invert_inplace = _elemwise(scal.invert, 'invert')
##########################
# Math
##########################
_abs, abs_inplace = _elemwise(scal.abs, 'abs') _abs, abs_inplace = _elemwise(scal.abs, 'abs')
exp, exp_inplace = _elemwise(scal.exp, 'exp') exp, exp_inplace = _elemwise(scal.exp, 'exp')
neg, neg_inplace = _elemwise(scal.neg, 'neg') neg, neg_inplace = _elemwise(scal.neg, 'neg')
...@@ -475,6 +515,11 @@ cosh, cosh_inplace = _elemwise(scal.cosh, 'cosh') ...@@ -475,6 +515,11 @@ cosh, cosh_inplace = _elemwise(scal.cosh, 'cosh')
sinh, sinh_inplace = _elemwise(scal.sinh, 'sinh') sinh, sinh_inplace = _elemwise(scal.sinh, 'sinh')
tanh, tanh_inplace = _elemwise(scal.tanh, 'tanh') tanh, tanh_inplace = _elemwise(scal.tanh, 'tanh')
##########################
# Misc
##########################
fill, fill_inplace = _elemwise(scal.second, 'fill') fill, fill_inplace = _elemwise(scal.second, 'fill')
def ones_like(model): def ones_like(model):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论