提交 cbd1304d authored 作者: Frederic's avatar Frederic

Disable grad/lop/rop/verify_grad on complex.

上级 b1013380
...@@ -132,6 +132,13 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -132,6 +132,13 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
#if all output gradients are None, continue #if all output gradients are None, continue
if all(map(lambda x: x is None, g_outputs)): continue if all(map(lambda x: x is None, g_outputs)): continue
#Disable all grad operation on complex. verify_grad don't
#support them and we don't know we want to handle them.
for var in node.inputs + node.outputs:
if (hasattr(var.type, 'dtype') and "complex" in var.type.dtype):
raise Exception("We do not support grad/Rop/Lop/verify_grad"
" on complex.")
output_arg = g_outputs output_arg = g_outputs
input_arg = node.inputs input_arg = node.inputs
......
...@@ -46,6 +46,7 @@ class TestRealImag(unittest.TestCase): ...@@ -46,6 +46,7 @@ class TestRealImag(unittest.TestCase):
assert numpy.all(rval == mval[0]), (rval,mval[0]) assert numpy.all(rval == mval[0]), (rval,mval[0])
assert numpy.all(ival == mval[1]), (ival, mval[1]) assert numpy.all(ival == mval[1]), (ival, mval[1])
@dec.knownfailureif(True,"Complex grads not enabled, see #178")
def test_complex_grads(self): def test_complex_grads(self):
def f(m): def f(m):
c = complex(m[0], m[1]) c = complex(m[0], m[1])
...@@ -103,6 +104,7 @@ class TestRealImag(unittest.TestCase): ...@@ -103,6 +104,7 @@ class TestRealImag(unittest.TestCase):
print e.analytic_grad print e.analytic_grad
raise raise
@dec.knownfailureif(True,"Complex grads not enabled, see #178")
def test_polar_grads(self): def test_polar_grads(self):
def f(m): def f(m):
c = complex_from_polar(abs(m[0]), m[1]) c = complex_from_polar(abs(m[0]), m[1])
...@@ -112,6 +114,7 @@ class TestRealImag(unittest.TestCase): ...@@ -112,6 +114,7 @@ class TestRealImag(unittest.TestCase):
mval = numpy.asarray(rng.randn(2,5)) mval = numpy.asarray(rng.randn(2,5))
utt.verify_grad(f, [mval]) utt.verify_grad(f, [mval])
@dec.knownfailureif(True,"Complex grads not enabled, see #178")
def test_abs_grad(self): def test_abs_grad(self):
def f(m): def f(m):
c = complex(m[0], m[1]) c = complex(m[0], m[1])
......
...@@ -628,6 +628,8 @@ class T_sum_dtype(unittest.TestCase): ...@@ -628,6 +628,8 @@ class T_sum_dtype(unittest.TestCase):
sum_var = x.sum(dtype=output_dtype, axis=axis) sum_var = x.sum(dtype=output_dtype, axis=axis)
assert sum_var.dtype == output_dtype assert sum_var.dtype == output_dtype
if "complex" in input_dtype:
continue
# Check that we can take the gradient # Check that we can take the gradient
grad_var = tensor.grad(sum_var.sum(), x, grad_var = tensor.grad(sum_var.sum(), x,
disconnected_inputs='ignore') disconnected_inputs='ignore')
...@@ -676,6 +678,8 @@ class T_mean_dtype(unittest.TestCase): ...@@ -676,6 +678,8 @@ class T_mean_dtype(unittest.TestCase):
assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype) assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype)
# Check that we can take the gradient, when implemented # Check that we can take the gradient, when implemented
if "complex" in mean_var.dtype:
continue
try: try:
grad_var = tensor.grad(mean_var.sum(), x, grad_var = tensor.grad(mean_var.sum(), x,
disconnected_inputs='ignore') disconnected_inputs='ignore')
...@@ -729,6 +733,8 @@ class T_prod_dtype(unittest.TestCase): ...@@ -729,6 +733,8 @@ class T_prod_dtype(unittest.TestCase):
prod_var = x.prod(dtype=output_dtype, axis=axis) prod_var = x.prod(dtype=output_dtype, axis=axis)
assert prod_var.dtype == output_dtype assert prod_var.dtype == output_dtype
if "complex" in output_dtype:
continue
# Check that we can take the gradient # Check that we can take the gradient
grad_var = tensor.grad(prod_var.sum(), x, grad_var = tensor.grad(prod_var.sum(), x,
disconnected_inputs='ignore') disconnected_inputs='ignore')
......
import numpy import numpy
from numpy.testing import dec
import theano import theano
from theano import tensor from theano import tensor
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -37,6 +39,7 @@ class TestFourier(utt.InferShapeTester): ...@@ -37,6 +39,7 @@ class TestFourier(utt.InferShapeTester):
[numpy.random.rand(12, 4), 0], [numpy.random.rand(12, 4), 0],
self.op_class) self.op_class)
@dec.knownfailureif(True, "Complex grads not enabled, see #178")
def test_gradient(self): def test_gradient(self):
def fft_test1(a): def fft_test1(a):
return self.op(a, None, None) return self.op(a, None, None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论