提交 cbd1304d authored 作者: Frederic's avatar Frederic

Disable grad/lop/rop/verify_grad on complex.

上级 b1013380
......@@ -132,6 +132,13 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
#if all output gradients are None, continue
if all(map(lambda x: x is None, g_outputs)): continue
#Disable all grad operation on complex. verify_grad don't
#support them and we don't know we want to handle them.
for var in node.inputs + node.outputs:
if (hasattr(var.type, 'dtype') and "complex" in var.type.dtype):
raise Exception("We do not support grad/Rop/Lop/verify_grad"
" on complex.")
output_arg = g_outputs
input_arg = node.inputs
......
......@@ -46,6 +46,7 @@ class TestRealImag(unittest.TestCase):
assert numpy.all(rval == mval[0]), (rval,mval[0])
assert numpy.all(ival == mval[1]), (ival, mval[1])
@dec.knownfailureif(True,"Complex grads not enabled, see #178")
def test_complex_grads(self):
def f(m):
c = complex(m[0], m[1])
......@@ -103,6 +104,7 @@ class TestRealImag(unittest.TestCase):
print e.analytic_grad
raise
@dec.knownfailureif(True,"Complex grads not enabled, see #178")
def test_polar_grads(self):
def f(m):
c = complex_from_polar(abs(m[0]), m[1])
......@@ -112,6 +114,7 @@ class TestRealImag(unittest.TestCase):
mval = numpy.asarray(rng.randn(2,5))
utt.verify_grad(f, [mval])
@dec.knownfailureif(True,"Complex grads not enabled, see #178")
def test_abs_grad(self):
def f(m):
c = complex(m[0], m[1])
......
......@@ -628,6 +628,8 @@ class T_sum_dtype(unittest.TestCase):
sum_var = x.sum(dtype=output_dtype, axis=axis)
assert sum_var.dtype == output_dtype
if "complex" in input_dtype:
continue
# Check that we can take the gradient
grad_var = tensor.grad(sum_var.sum(), x,
disconnected_inputs='ignore')
......@@ -676,6 +678,8 @@ class T_mean_dtype(unittest.TestCase):
assert mean_var.dtype == sum_dtype, (mean_var.dtype, sum_dtype)
# Check that we can take the gradient, when implemented
if "complex" in mean_var.dtype:
continue
try:
grad_var = tensor.grad(mean_var.sum(), x,
disconnected_inputs='ignore')
......@@ -729,6 +733,8 @@ class T_prod_dtype(unittest.TestCase):
prod_var = x.prod(dtype=output_dtype, axis=axis)
assert prod_var.dtype == output_dtype
if "complex" in output_dtype:
continue
# Check that we can take the gradient
grad_var = tensor.grad(prod_var.sum(), x,
disconnected_inputs='ignore')
......
import numpy
from numpy.testing import dec
import theano
from theano import tensor
from theano.tests import unittest_tools as utt
......@@ -37,6 +39,7 @@ class TestFourier(utt.InferShapeTester):
[numpy.random.rand(12, 4), 0],
self.op_class)
@dec.knownfailureif(True, "Complex grads not enabled, see #178")
def test_gradient(self):
def fft_test1(a):
return self.op(a, None, None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论