提交 3f240fd7 authored 作者: Frederic Bastien's avatar Frederic Bastien

--- theano/tensor/tests/test_basic.py

+++ theano/tensor/tests/test_basic.py (reindented) @@ -1852,7 +1852,7 @@ r = numpy.asarray([0.,1.,-1.], dtype=dtype) v = fn(l, r) self.failUnless(numpy.all(v == (l == r)), (v, (l==r))) - + def test_neq(self): for dtype in ['float64', 'float32', 'complex64', 'complex128']: x, y = vector(dtype=dtype), vector(dtype=dtype)
上级 03e45a99
......@@ -635,8 +635,11 @@ class LT(LogicalComparison):
commutative = False
associative = False
def impl(self, x, y):
return x < y
# built-in < don't support complex
return numpy.less(x, y)
def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s < %(y)s);" % locals()
lt = LT()
......@@ -645,7 +648,8 @@ class GT(LogicalComparison):
commutative = False
associative = False
def impl(self, x, y):
return x > y
# built-in > don't support complex
return numpy.greater(x, y)
def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
......@@ -657,7 +661,8 @@ class LE(LogicalComparison):
commutative = False
associative = False
def impl(self, x, y):
return x <= y
# built-in <= don't support complex
return numpy.less_equal(x, y)
def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
......@@ -669,7 +674,8 @@ class GE(LogicalComparison):
commutative = False
associative = False
def impl(self, x, y):
return x >= y
# built-in >= don't support complex
return numpy.greater_equal(x, y)
def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
......@@ -695,6 +701,8 @@ class NEQ(LogicalComparison):
def impl(self, x, y):
return x != y
def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s != %(y)s);" % locals()
neq = NEQ()
......@@ -772,7 +780,7 @@ class UnaryBitOp(UnaryScalarOp):
def output_types(self, *input_types):
for i in input_types[0]:
if i not in (int8, int32, int64):
raise TypeError('input to a BitOp must have type int8, int32 or int 64... not %s' % i)
raise TypeError('input to a BitOp must have type int8, int32 or int64... not %s' % i)
return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients):
return [None]
......@@ -782,7 +790,7 @@ class BinaryBitOp(BinaryScalarOp):
t0, t1 = input_types[0]
for i in input_types[0]:
if i not in (int8, int32, int64):
raise TypeError('input to a BitOp must have type int8, int32 or int 64... not %s' % i)
raise TypeError('input to a BitOp must have type int8, int32 or int64... not %s' % i)
return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients):
return [None, None]
......
......@@ -1809,52 +1809,58 @@ class T_Join_and_Split(unittest.TestCase):
class test_comparison(unittest.TestCase):
def test_gt(self):
x, y = fvector(), fvector()
fn = inplace_func([x,y], x > y)
l = numpy.asarray([0.,-1.,1.], dtype='float32')
r = numpy.asarray([0.,1.,-1.], dtype='float32')
v = fn(l, r)
self.failUnless(numpy.all(v == (l > r)), (v, (l>r)))
for dtype in ['float64', 'float32', 'complex64', 'complex128']:
x, y = vector(dtype=dtype), vector(dtype=dtype)
fn = inplace_func([x,y], x > y)
l = numpy.asarray([0.,-1.,1.], dtype=dtype)
r = numpy.asarray([0.,1.,-1.], dtype=dtype)
v = fn(l, r)
self.failUnless(numpy.all(v == (l > r)), (v, (l>r)))
def test_lt(self):
x, y = fvector(), fvector()
fn = inplace_func([x,y], x < y)
l = numpy.asarray([0.,-1.,1.], dtype='float32')
r = numpy.asarray([0.,1.,-1.], dtype='float32')
v = fn(l, r)
self.failUnless(numpy.all(v == (l < r)), (v, (l<r)))
for dtype in ['float64', 'float32', 'complex64', 'complex128']:
x, y = vector(dtype=dtype), vector(dtype=dtype)
fn = inplace_func([x,y], x < y)
l = numpy.asarray([0.,-1.,1.], dtype=dtype)
r = numpy.asarray([0.,1.,-1.], dtype=dtype)
v = fn(l, r)
self.failUnless(numpy.all(v == (l < r)), (v, (l<r)))
def test_le(self):
x, y = fvector(), fvector()
fn = inplace_func([x,y], x <= y)
l = numpy.asarray([0.,-1.,1.], dtype='float32')
r = numpy.asarray([0.,1.,-1.], dtype='float32')
v = fn(l, r)
self.failUnless(numpy.all(v == (l <= r)), (v, (l<=r)))
for dtype in ['float64', 'float32', 'complex64', 'complex128']:
x, y = vector(dtype=dtype), vector(dtype=dtype)
fn = inplace_func([x,y], x <= y)
l = numpy.asarray([0.,-1.,1.], dtype=dtype)
r = numpy.asarray([0.,1.,-1.], dtype=dtype)
v = fn(l, r)
self.failUnless(numpy.all(v == (l <= r)), (v, (l<=r)))
def test_ge(self):
x, y = fvector(), fvector()
fn = inplace_func([x,y], x >= y)
l = numpy.asarray([0.,-1.,1.], dtype='float32')
r = numpy.asarray([0.,1.,-1.], dtype='float32')
v = fn(l, r)
self.failUnless(numpy.all(v == (l >= r)), (v, (l>=r)))
for dtype in ['float64', 'float32', 'complex64', 'complex128']:
x, y = vector(dtype=dtype), vector(dtype=dtype)
fn = inplace_func([x,y], x >= y)
l = numpy.asarray([0.,-1.,1.], dtype=dtype)
r = numpy.asarray([0.,1.,-1.], dtype=dtype)
v = fn(l, r)
self.failUnless(numpy.all(v == (l >= r)), (v, (l>=r)))
def test_eq(self):
x, y = fvector(), fvector()
fn = inplace_func([x,y], eq(x,y))
l = numpy.asarray([0.,-1.,1.], dtype='float32')
r = numpy.asarray([0.,1.,-1.], dtype='float32')
v = fn(l, r)
self.failUnless(numpy.all(v == (l == r)), (v, (l==r)))
for dtype in ['float64', 'float32', 'complex64', 'complex128']:
x, y = vector(dtype=dtype), vector(dtype=dtype)
fn = inplace_func([x,y], eq(x,y))
l = numpy.asarray([0.,-1.,1.], dtype=dtype)
r = numpy.asarray([0.,1.,-1.], dtype=dtype)
v = fn(l, r)
self.failUnless(numpy.all(v == (l == r)), (v, (l==r)))
def test_neq(self):
x, y = fvector(), fvector()
fn = inplace_func([x,y], neq(x, y))
l = numpy.asarray([0.,-1.,1.], dtype='float32')
r = numpy.asarray([0.,1.,-1.], dtype='float32')
v = fn(l, r)
self.failUnless(numpy.all(v == (l != r)), (v, (l!=r)))
for dtype in ['float64', 'float32', 'complex64', 'complex128']:
x, y = vector(dtype=dtype), vector(dtype=dtype)
fn = inplace_func([x,y], neq(x, y))
l = numpy.asarray([0.,-1.,1.], dtype=dtype)
r = numpy.asarray([0.,1.,-1.], dtype=dtype)
v = fn(l, r)
self.failUnless(numpy.all(v == (l != r)), (v, (l!=r)))
class test_bitwise(unittest.TestCase):
def test_or(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论