提交 3f240fd7 authored 作者: Frederic Bastien's avatar Frederic Bastien

--- theano/tensor/tests/test_basic.py

+++ theano/tensor/tests/test_basic.py (reindented) @@ -1852,7 +1852,7 @@ r = numpy.asarray([0.,1.,-1.], dtype=dtype) v = fn(l, r) self.failUnless(numpy.all(v == (l == r)), (v, (l==r))) - + def test_neq(self): for dtype in ['float64', 'float32', 'complex64', 'complex128']: x, y = vector(dtype=dtype), vector(dtype=dtype)
上级 03e45a99
...@@ -635,8 +635,11 @@ class LT(LogicalComparison): ...@@ -635,8 +635,11 @@ class LT(LogicalComparison):
commutative = False commutative = False
associative = False associative = False
def impl(self, x, y): def impl(self, x, y):
return x < y # built-in < don't support complex
return numpy.less(x, y)
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s < %(y)s);" % locals() return "%(z)s = (%(x)s < %(y)s);" % locals()
lt = LT() lt = LT()
...@@ -645,7 +648,8 @@ class GT(LogicalComparison): ...@@ -645,7 +648,8 @@ class GT(LogicalComparison):
commutative = False commutative = False
associative = False associative = False
def impl(self, x, y): def impl(self, x, y):
return x > y # built-in > don't support complex
return numpy.greater(x, y)
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError() raise NotImplementedError()
...@@ -657,7 +661,8 @@ class LE(LogicalComparison): ...@@ -657,7 +661,8 @@ class LE(LogicalComparison):
commutative = False commutative = False
associative = False associative = False
def impl(self, x, y): def impl(self, x, y):
return x <= y # built-in <= don't support complex
return numpy.less_equal(x, y)
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError() raise NotImplementedError()
...@@ -669,7 +674,8 @@ class GE(LogicalComparison): ...@@ -669,7 +674,8 @@ class GE(LogicalComparison):
commutative = False commutative = False
associative = False associative = False
def impl(self, x, y): def impl(self, x, y):
return x >= y # built-in >= don't support complex
return numpy.greater_equal(x, y)
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError() raise NotImplementedError()
...@@ -695,6 +701,8 @@ class NEQ(LogicalComparison): ...@@ -695,6 +701,8 @@ class NEQ(LogicalComparison):
def impl(self, x, y): def impl(self, x, y):
return x != y return x != y
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s != %(y)s);" % locals() return "%(z)s = (%(x)s != %(y)s);" % locals()
neq = NEQ() neq = NEQ()
...@@ -772,7 +780,7 @@ class UnaryBitOp(UnaryScalarOp): ...@@ -772,7 +780,7 @@ class UnaryBitOp(UnaryScalarOp):
def output_types(self, *input_types): def output_types(self, *input_types):
for i in input_types[0]: for i in input_types[0]:
if i not in (int8, int32, int64): if i not in (int8, int32, int64):
raise TypeError('input to a BitOp must have type int8, int32 or int 64... not %s' % i) raise TypeError('input to a BitOp must have type int8, int32 or int64... not %s' % i)
return upcast_out(*input_types[0]) return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
return [None] return [None]
...@@ -782,7 +790,7 @@ class BinaryBitOp(BinaryScalarOp): ...@@ -782,7 +790,7 @@ class BinaryBitOp(BinaryScalarOp):
t0, t1 = input_types[0] t0, t1 = input_types[0]
for i in input_types[0]: for i in input_types[0]:
if i not in (int8, int32, int64): if i not in (int8, int32, int64):
raise TypeError('input to a BitOp must have type int8, int32 or int 64... not %s' % i) raise TypeError('input to a BitOp must have type int8, int32 or int64... not %s' % i)
return upcast_out(*input_types[0]) return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
return [None, None] return [None, None]
......
...@@ -1809,50 +1809,56 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -1809,50 +1809,56 @@ class T_Join_and_Split(unittest.TestCase):
class test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
def test_gt(self): def test_gt(self):
x, y = fvector(), fvector() for dtype in ['float64', 'float32', 'complex64', 'complex128']:
x, y = vector(dtype=dtype), vector(dtype=dtype)
fn = inplace_func([x,y], x > y) fn = inplace_func([x,y], x > y)
l = numpy.asarray([0.,-1.,1.], dtype='float32') l = numpy.asarray([0.,-1.,1.], dtype=dtype)
r = numpy.asarray([0.,1.,-1.], dtype='float32') r = numpy.asarray([0.,1.,-1.], dtype=dtype)
v = fn(l, r) v = fn(l, r)
self.failUnless(numpy.all(v == (l > r)), (v, (l>r))) self.failUnless(numpy.all(v == (l > r)), (v, (l>r)))
def test_lt(self): def test_lt(self):
x, y = fvector(), fvector() for dtype in ['float64', 'float32', 'complex64', 'complex128']:
x, y = vector(dtype=dtype), vector(dtype=dtype)
fn = inplace_func([x,y], x < y) fn = inplace_func([x,y], x < y)
l = numpy.asarray([0.,-1.,1.], dtype='float32') l = numpy.asarray([0.,-1.,1.], dtype=dtype)
r = numpy.asarray([0.,1.,-1.], dtype='float32') r = numpy.asarray([0.,1.,-1.], dtype=dtype)
v = fn(l, r) v = fn(l, r)
self.failUnless(numpy.all(v == (l < r)), (v, (l<r))) self.failUnless(numpy.all(v == (l < r)), (v, (l<r)))
def test_le(self): def test_le(self):
x, y = fvector(), fvector() for dtype in ['float64', 'float32', 'complex64', 'complex128']:
x, y = vector(dtype=dtype), vector(dtype=dtype)
fn = inplace_func([x,y], x <= y) fn = inplace_func([x,y], x <= y)
l = numpy.asarray([0.,-1.,1.], dtype='float32') l = numpy.asarray([0.,-1.,1.], dtype=dtype)
r = numpy.asarray([0.,1.,-1.], dtype='float32') r = numpy.asarray([0.,1.,-1.], dtype=dtype)
v = fn(l, r) v = fn(l, r)
self.failUnless(numpy.all(v == (l <= r)), (v, (l<=r))) self.failUnless(numpy.all(v == (l <= r)), (v, (l<=r)))
def test_ge(self): def test_ge(self):
x, y = fvector(), fvector() for dtype in ['float64', 'float32', 'complex64', 'complex128']:
x, y = vector(dtype=dtype), vector(dtype=dtype)
fn = inplace_func([x,y], x >= y) fn = inplace_func([x,y], x >= y)
l = numpy.asarray([0.,-1.,1.], dtype='float32') l = numpy.asarray([0.,-1.,1.], dtype=dtype)
r = numpy.asarray([0.,1.,-1.], dtype='float32') r = numpy.asarray([0.,1.,-1.], dtype=dtype)
v = fn(l, r) v = fn(l, r)
self.failUnless(numpy.all(v == (l >= r)), (v, (l>=r))) self.failUnless(numpy.all(v == (l >= r)), (v, (l>=r)))
def test_eq(self): def test_eq(self):
x, y = fvector(), fvector() for dtype in ['float64', 'float32', 'complex64', 'complex128']:
x, y = vector(dtype=dtype), vector(dtype=dtype)
fn = inplace_func([x,y], eq(x,y)) fn = inplace_func([x,y], eq(x,y))
l = numpy.asarray([0.,-1.,1.], dtype='float32') l = numpy.asarray([0.,-1.,1.], dtype=dtype)
r = numpy.asarray([0.,1.,-1.], dtype='float32') r = numpy.asarray([0.,1.,-1.], dtype=dtype)
v = fn(l, r) v = fn(l, r)
self.failUnless(numpy.all(v == (l == r)), (v, (l==r))) self.failUnless(numpy.all(v == (l == r)), (v, (l==r)))
def test_neq(self): def test_neq(self):
x, y = fvector(), fvector() for dtype in ['float64', 'float32', 'complex64', 'complex128']:
x, y = vector(dtype=dtype), vector(dtype=dtype)
fn = inplace_func([x,y], neq(x, y)) fn = inplace_func([x,y], neq(x, y))
l = numpy.asarray([0.,-1.,1.], dtype='float32') l = numpy.asarray([0.,-1.,1.], dtype=dtype)
r = numpy.asarray([0.,1.,-1.], dtype='float32') r = numpy.asarray([0.,1.,-1.], dtype=dtype)
v = fn(l, r) v = fn(l, r)
self.failUnless(numpy.all(v == (l != r)), (v, (l!=r))) self.failUnless(numpy.all(v == (l != r)), (v, (l!=r)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论