Commit a81fbe8a authored by abergeron

Merge pull request #777 from nouiz/fix_test

Don't test mixing complex output sum with scalar input since the gradient will have to do a complex->scalar cast and it is not well-defined.
@@ -514,6 +514,8 @@ def local_gpu_ger(node):
        tensor.blas_c.CGer(destructive=False): gpu_ger_no_inplace,
        tensor.blas.Ger(destructive=True): gpu_ger_no_inplace,
        tensor.blas.Ger(destructive=False): gpu_ger_no_inplace,
+       tensor.blas_scipy.ScipyGer(destructive=True): gpu_ger_no_inplace,
+       tensor.blas_scipy.ScipyGer(destructive=False): gpu_ger_no_inplace,
    }
    if node.op == gpu_from_host:
        host_input = node.inputs[0]
......
@@ -1399,7 +1399,9 @@ class GemmOptimizer(Optimizer):
                    zip(node.outputs, new_outputs),
                    [old_dot22],
                    reason='GemmOptimizer',
-                   warn=not self.warned
+                   # For now we disable the warning as we know of cases
+                   # that we need to fix.
+                   warn=False,  # warn=not self.warned
                )
                did_something = True
                nb_replacement += 1
......
@@ -618,6 +618,13 @@ class T_sum_dtype(unittest.TestCase):
        for input_dtype in imap(str, theano.scalar.all_types):
            x = tensor.matrix(dtype=input_dtype)
            for output_dtype in imap(str, theano.scalar.all_types):
+               # If the output is complex, the gradient of the sum will
+               # cast the complex to the input dtype. We can't call the
+               # normal cast on a complex to a non-complex dtype, as this
+               # is ambiguous.
+               if (not input_dtype.startswith('complex') and
+                       output_dtype.startswith('complex')):
+                   continue
                axis = axes[idx % len(axes)]
                # If output_dtype would force a downcast, we expect a TypeError
                # We always allow int/uint inputs with float/complex outputs.
......
Markdown formatting is supported.
0%
You added 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment.