提交 574c285f authored 作者: James Bergstra's avatar James Bergstra

added logical and comparison ops to scalar.py

上级 c331e558
......@@ -55,7 +55,56 @@ class _test_composite(unittest.TestCase):
g = Env([x, y, z], c.outputs)
fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
class _test_logical(unittest.TestCase):
def test_gt(self):
x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [x > y])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a>b))
def test_lt(self):
x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [x < y])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a<b))
def test_le(self):
x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [x <= y])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a<=b))
def test_ge(self):
x, y, z = inputs()
fn = gof.DualLinker(Env([x,y], [x >= y])).make_function()
for a,b in ((3.,9), (3,0.9), (3,3)):
self.failUnless(fn(a,b) == (a>=b))
def test_or(self):
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [or_(x, y)])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a or b), (a,b))
def test_xor(self):
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [xor(x, y)])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (operator.xor(a, b)), (a,b))
def test_and(self):
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [and_(x, y)])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (a and b), (a,b))
def test_not(self):
x, y, z = ints('xyz')
fn = gof.DualLinker(Env([x,y], [not_(x)])).make_function()
for a,b in ((0,1), (0,0), (1,0), (1,1)):
self.failUnless(fn(a,b) == (not a), (a,))
if __name__ == '__main__':
unittest.main()
......
import numpy
import operator
import math
from copy import copy
import numpy
import gof
from gof import PropertiedType, Op, PropertiedOp, utils, Result, Constant, Type, Apply, Env
from gof.python25 import partial
......@@ -326,6 +326,91 @@ class UnaryScalarOp(ScalarOp):
class BinaryScalarOp(ScalarOp):
nin = 2
class UnaryLogicalOp(UnaryScalarOp):
def output_types(self, *input_dtypes):
return [int8]
def grad(self, inputs, output_gradients):
return [None]
class BinaryLogicalOp(BinaryScalarOp):
def output_types(self, *input_dtypes):
return [int8]
def grad(self, inputs, output_gradients):
return [None, None]
class LT(BinaryLogicalOp):
identity = False
commutative = False
associative = False
def impl(self, x, y):
return x < y
lt = LT()
class GT(BinaryLogicalOp):
identity = False
commutative = False
associative = False
def impl(self, x, y):
return x > y
gt = GT()
class LE(BinaryLogicalOp):
identity = False
commutative = False
associative = False
def impl(self, x, y):
return x <= y
le = LE()
class GE(BinaryLogicalOp):
identity = False
commutative = False
associative = False
def impl(self, x, y):
return x >= y
ge = GE()
class EQ(BinaryLogicalOp):
identity = False
commutative = True
associative = False
def impl(self, x, y):
return x == y
eq = EQ()
class OR(BinaryLogicalOp):
identity = False
commutative = True
associative = False
def impl(self, x, y):
return (x or y)
or_ = OR()
class XOR(BinaryLogicalOp):
identity = False
commutative = True
associative = False
def impl(self, x, y):
return operator.xor(x, y)
xor = XOR()
class AND(BinaryLogicalOp):
identity = False
commutative = True
associative = False
def impl(self, x, y):
return x and y
and_ = AND()
class NOT(UnaryLogicalOp):
identity = False
def impl(self, x):
return not x
not_ = NOT()
class Add(ScalarOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论