提交 72132077 authored 作者: Frederic's avatar Frederic

make some reduction test work with complex.

上级 7af328c7
...@@ -301,6 +301,12 @@ class Scalar(Type): ...@@ -301,6 +301,12 @@ class Scalar(Type):
ret.imag = -this->imag; ret.imag = -this->imag;
return ret; return ret;
} }
bool operator ==(const complex_type &y) const {
return (this->real == y.real) && (this->imag == y.imag);
}
bool operator ==(const npy_float%(nbits)s &y) const {
return (this->real == y) && (this->imag == 0);
}
complex_type operator -(const complex_type &y) const { complex_type operator -(const complex_type &y) const {
complex_type ret; complex_type ret;
ret.real = this->real - y.real; ret.real = this->real - y.real;
......
...@@ -679,16 +679,16 @@ class T_sum_dtype(unittest.TestCase): ...@@ -679,16 +679,16 @@ class T_sum_dtype(unittest.TestCase):
sum_var = x.sum(dtype=output_dtype, axis=axis) sum_var = x.sum(dtype=output_dtype, axis=axis)
assert sum_var.dtype == output_dtype assert sum_var.dtype == output_dtype
f = theano.function([x], sum_var)
data = numpy.random.rand(3, 4) * 10
data = data.astype(input_dtype)
f(data)
if "complex" in input_dtype: if "complex" in input_dtype:
continue continue
# Check that we can take the gradient # Check that we can take the gradient
tensor.grad(sum_var.sum(), x, tensor.grad(sum_var.sum(), x,
disconnected_inputs='ignore') disconnected_inputs='ignore')
idx += 1 idx += 1
f = theano.function([x], sum_var)
data = numpy.random.rand(3, 4) * 10
data = data.astype(input_dtype)
f(data)
def test_sum_custom_acc_dtype(self): def test_sum_custom_acc_dtype(self):
""" """
...@@ -783,7 +783,10 @@ class T_mean_dtype(unittest.TestCase): ...@@ -783,7 +783,10 @@ class T_mean_dtype(unittest.TestCase):
else: else:
assert mean_var.dtype == sum_dtype, ( assert mean_var.dtype == sum_dtype, (
(mean_var.dtype, sum_dtype)) (mean_var.dtype, sum_dtype))
if ("complex" not in sum_dtype and "complex" not in input_dtype): if (('complex' in input_dtype or
'complex' in sum_dtype) and
input_dtype != sum_dtype):
continue
f = theano.function([x], mean_var) f = theano.function([x], mean_var)
data = numpy.random.rand(3, 4) * 10 data = numpy.random.rand(3, 4) * 10
data = data.astype(input_dtype) data = data.astype(input_dtype)
...@@ -873,19 +876,25 @@ class T_prod_dtype(unittest.TestCase): ...@@ -873,19 +876,25 @@ class T_prod_dtype(unittest.TestCase):
x = tensor.matrix(dtype=input_dtype) x = tensor.matrix(dtype=input_dtype)
for output_dtype in imap(str, theano.scalar.all_types): for output_dtype in imap(str, theano.scalar.all_types):
axis = axes[idx % len(axes)] axis = axes[idx % len(axes)]
idx += 1
prod_var = x.prod(dtype=output_dtype, axis=axis) prod_var = x.prod(dtype=output_dtype, axis=axis)
assert prod_var.dtype == output_dtype assert prod_var.dtype == output_dtype
if "complex" in output_dtype or "complex" in input_dtype: if (('complex' in output_dtype or
'complex' in input_dtype) and
input_dtype != output_dtype):
continue continue
# Check that we can take the gradient
tensor.grad(prod_var.sum(), x,
disconnected_inputs='ignore')
f = theano.function([x], prod_var) f = theano.function([x], prod_var)
data = numpy.random.rand(3, 4) * 10 data = numpy.random.rand(3, 4) * 10
data = data.astype(input_dtype) data = data.astype(input_dtype)
f(data) f(data)
idx += 1
if "complex" in output_dtype or "complex" in input_dtype:
continue
# Check that we can take the gradient
tensor.grad(prod_var.sum(), x,
disconnected_inputs='ignore')
def test_prod_custom_acc_dtype(self): def test_prod_custom_acc_dtype(self):
""" """
...@@ -908,15 +917,19 @@ class T_prod_dtype(unittest.TestCase): ...@@ -908,15 +917,19 @@ class T_prod_dtype(unittest.TestCase):
prod_var = x.prod(acc_dtype=acc_dtype, axis=axis) prod_var = x.prod(acc_dtype=acc_dtype, axis=axis)
assert prod_var.owner.op.acc_dtype == acc_dtype assert prod_var.owner.op.acc_dtype == acc_dtype
if "complex" in acc_dtype: if (acc_dtype.startswith('complex') and
input_dtype != acc_dtype):
continue continue
# Check that we can take the gradient
tensor.grad(prod_var.sum(), x,
disconnected_inputs='ignore')
f = theano.function([x], prod_var) f = theano.function([x], prod_var)
data = numpy.random.rand(3, 4) * 10 data = numpy.random.rand(3, 4) * 10
data = data.astype(input_dtype) data = data.astype(input_dtype)
f(data) f(data)
if "complex" in acc_dtype:
continue
# Check that we can take the gradient
tensor.grad(prod_var.sum(), x,
disconnected_inputs='ignore')
else: else:
self.assertRaises(TypeError, self.assertRaises(TypeError,
x.prod, acc_dtype=acc_dtype, axis=axis) x.prod, acc_dtype=acc_dtype, axis=axis)
...@@ -967,7 +980,7 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -967,7 +980,7 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
if 'complex' in dtype: if 'complex' in dtype:
continue continue
f = theano.function([x], p) f = theano.function([x], p)
data = numpy.random.rand(3, 4) * 10 data = numpy.random.rand(2, 3) * 3
data = data.astype(dtype) data = data.astype(dtype)
f(data) f(data)
...@@ -985,13 +998,14 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -985,13 +998,14 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
prod_woz_var = ProdWithoutZeros( prod_woz_var = ProdWithoutZeros(
axis=axis, dtype=output_dtype)(x) axis=axis, dtype=output_dtype)(x)
assert prod_woz_var.dtype == output_dtype assert prod_woz_var.dtype == output_dtype
if ('complex' not in input_dtype and idx += 1
'complex' not in output_dtype): if ('complex' in output_dtype or
'complex' in input_dtype):
continue
f = theano.function([x], prod_woz_var) f = theano.function([x], prod_woz_var)
data = numpy.random.rand(3, 4) * 10 data = numpy.random.rand(2, 3) * 3
data = data.astype(input_dtype) data = data.astype(input_dtype)
f(data) f(data)
idx += 1
def test_prod_without_zeros_custom_acc_dtype(self): def test_prod_without_zeros_custom_acc_dtype(self):
""" """
...@@ -1015,7 +1029,8 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -1015,7 +1029,8 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
axis=axis, acc_dtype=acc_dtype)(x) axis=axis, acc_dtype=acc_dtype)(x)
assert prod_woz_var.owner.op.acc_dtype == acc_dtype assert prod_woz_var.owner.op.acc_dtype == acc_dtype
if acc_dtype.startswith('complex') and input_dtype != acc_dtype: if (acc_dtype.startswith('complex') and
input_dtype != acc_dtype):
continue continue
f = theano.function([x], prod_woz_var) f = theano.function([x], prod_woz_var)
data = numpy.random.rand(2, 3) * 3 data = numpy.random.rand(2, 3) * 3
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论