提交 03e45a99 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix crash on elemwise with complex* type. Disable some c_code for complex.

上级 82ca7c85
......@@ -431,6 +431,11 @@ complexs128 = _multi(complex128)
def upcast_out(*types):
return Scalar(dtype = Scalar.upcast(*types)),
def upcast_out_no_complex(*types):
if any([type not in float_types for type in types]):
raise TypeError('complex type are supported')
return Scalar(dtype = Scalar.upcast(*types)),
def same_out(type):
return type,
......@@ -481,6 +486,14 @@ def upgrade_to_float(*types):
int32: float64,
int64: float64}
return Scalar(Scalar.upcast(*[conv.get(type, type) for type in types])),
def upgrade_to_float_no_complex(*types):
"""
don't accept complex, otherwise call upgrade_to_float().
"""
for type in types:
if type in complex_types:
raise TypeError('complex argument not supported')
return upgrade_to_float(*types)
def same_out_nocomplex(type):
if type in complex_types:
raise TypeError('complex argument not supported')
......@@ -634,6 +647,8 @@ class GT(LogicalComparison):
def impl(self, x, y):
return x > y
def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s > %(y)s);" % locals()
gt = GT()
......@@ -644,6 +659,8 @@ class LE(LogicalComparison):
def impl(self, x, y):
return x <= y
def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s <= %(y)s);" % locals()
le = LE()
......@@ -654,6 +671,8 @@ class GE(LogicalComparison):
def impl(self, x, y):
return x >= y
def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s >= %(y)s);" % locals()
ge = GE()
......@@ -664,6 +683,8 @@ class EQ(LogicalComparison):
def impl(self, x, y):
return x == y
def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s == %(y)s);" % locals()
eq = EQ()
......@@ -806,8 +827,11 @@ class Maximum(BinaryScalarOp):
commutative = True
associative = True
def impl(self, *inputs):
return max(inputs)
# The built-in max function don't support complex type
return numpy.maximum(*inputs)
def c_code(self, node, name, (x,y), (z, ), sub):
if any([i.type in complex_types for i in node.inputs]):
raise NotImplementedError()
return "%(z)s = ((%(y)s)>(%(x)s)? (%(y)s):(%(x)s));" %locals()
def grad(self, (x, y), (gz, )):
......@@ -826,8 +850,11 @@ class Minimum(BinaryScalarOp):
commutative = True
associative = True
def impl(self, *inputs):
return min(inputs)
# The built-in min function don't support complex type
return numpy.minimum(*inputs)
def c_code(self, node, name, (x,y), (z, ), sub):
if any([i.type in complex_types for i in node.inputs]):
raise NotImplementedError()
return "%(z)s = ((%(y)s)<(%(x)s)? (%(y)s):(%(x)s));" %locals()
def grad(self, (x, y), (gz, )):
......
#definition theano.scalar op that have their python implementation taked from scipy
#as scipy is not always available, we put threat them separatly
from theano.scalar.basic import UnaryScalarOp,exp,sqrt,upgrade_to_float,complex_types,float_types,upcast
from theano.scalar.basic import UnaryScalarOp,exp,upgrade_to_float,upgrade_to_float_no_complex,complex_types,float_types,upcast
import numpy
imported_scipy_special = False
......@@ -49,4 +49,6 @@ class Erfc(UnaryScalarOp):
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = erfc(%(x)s);" % locals()
erfc = Erfc(upgrade_to_float, name = 'erfc')
# scipy.special.erfc don't support complex. Why?
erfc = Erfc(upgrade_to_float_no_complex, name = 'erfc')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论