提交 03e45a99 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix crash on elemwise with complex* type. Disable some c_code for complex.

上级 82ca7c85
...@@ -431,6 +431,11 @@ complexs128 = _multi(complex128) ...@@ -431,6 +431,11 @@ complexs128 = _multi(complex128)
def upcast_out(*types): def upcast_out(*types):
return Scalar(dtype = Scalar.upcast(*types)),
def upcast_out_no_complex(*types):
if any([type not in float_types for type in types]):
raise TypeError('complex type are supported')
return Scalar(dtype = Scalar.upcast(*types)), return Scalar(dtype = Scalar.upcast(*types)),
def same_out(type): def same_out(type):
return type, return type,
...@@ -481,6 +486,14 @@ def upgrade_to_float(*types): ...@@ -481,6 +486,14 @@ def upgrade_to_float(*types):
int32: float64, int32: float64,
int64: float64} int64: float64}
return Scalar(Scalar.upcast(*[conv.get(type, type) for type in types])), return Scalar(Scalar.upcast(*[conv.get(type, type) for type in types])),
def upgrade_to_float_no_complex(*types):
"""
don't accept complex, otherwise call upgrade_to_float().
"""
for type in types:
if type in complex_types:
raise TypeError('complex argument not supported')
return upgrade_to_float(*types)
def same_out_nocomplex(type): def same_out_nocomplex(type):
if type in complex_types: if type in complex_types:
raise TypeError('complex argument not supported') raise TypeError('complex argument not supported')
...@@ -634,6 +647,8 @@ class GT(LogicalComparison): ...@@ -634,6 +647,8 @@ class GT(LogicalComparison):
def impl(self, x, y): def impl(self, x, y):
return x > y return x > y
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s > %(y)s);" % locals() return "%(z)s = (%(x)s > %(y)s);" % locals()
gt = GT() gt = GT()
...@@ -644,6 +659,8 @@ class LE(LogicalComparison): ...@@ -644,6 +659,8 @@ class LE(LogicalComparison):
def impl(self, x, y): def impl(self, x, y):
return x <= y return x <= y
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s <= %(y)s);" % locals() return "%(z)s = (%(x)s <= %(y)s);" % locals()
le = LE() le = LE()
...@@ -654,6 +671,8 @@ class GE(LogicalComparison): ...@@ -654,6 +671,8 @@ class GE(LogicalComparison):
def impl(self, x, y): def impl(self, x, y):
return x >= y return x >= y
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s >= %(y)s);" % locals() return "%(z)s = (%(x)s >= %(y)s);" % locals()
ge = GE() ge = GE()
...@@ -664,6 +683,8 @@ class EQ(LogicalComparison): ...@@ -664,6 +683,8 @@ class EQ(LogicalComparison):
def impl(self, x, y): def impl(self, x, y):
return x == y return x == y
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s == %(y)s);" % locals() return "%(z)s = (%(x)s == %(y)s);" % locals()
eq = EQ() eq = EQ()
...@@ -806,8 +827,11 @@ class Maximum(BinaryScalarOp): ...@@ -806,8 +827,11 @@ class Maximum(BinaryScalarOp):
commutative = True commutative = True
associative = True associative = True
def impl(self, *inputs): def impl(self, *inputs):
return max(inputs) # The built-in max function don't support complex type
return numpy.maximum(*inputs)
def c_code(self, node, name, (x,y), (z, ), sub): def c_code(self, node, name, (x,y), (z, ), sub):
if any([i.type in complex_types for i in node.inputs]):
raise NotImplementedError()
return "%(z)s = ((%(y)s)>(%(x)s)? (%(y)s):(%(x)s));" %locals() return "%(z)s = ((%(y)s)>(%(x)s)? (%(y)s):(%(x)s));" %locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
...@@ -826,8 +850,11 @@ class Minimum(BinaryScalarOp): ...@@ -826,8 +850,11 @@ class Minimum(BinaryScalarOp):
commutative = True commutative = True
associative = True associative = True
def impl(self, *inputs): def impl(self, *inputs):
return min(inputs) # The built-in min function don't support complex type
return numpy.minimum(*inputs)
def c_code(self, node, name, (x,y), (z, ), sub): def c_code(self, node, name, (x,y), (z, ), sub):
if any([i.type in complex_types for i in node.inputs]):
raise NotImplementedError()
return "%(z)s = ((%(y)s)<(%(x)s)? (%(y)s):(%(x)s));" %locals() return "%(z)s = ((%(y)s)<(%(x)s)? (%(y)s):(%(x)s));" %locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
......
#definition theano.scalar op that have their python implementation taked from scipy #definition theano.scalar op that have their python implementation taked from scipy
#as scipy is not always available, we put threat them separatly #as scipy is not always available, we put threat them separatly
from theano.scalar.basic import UnaryScalarOp,exp,sqrt,upgrade_to_float,complex_types,float_types,upcast from theano.scalar.basic import UnaryScalarOp,exp,upgrade_to_float,upgrade_to_float_no_complex,complex_types,float_types,upcast
import numpy import numpy
imported_scipy_special = False imported_scipy_special = False
...@@ -49,4 +49,6 @@ class Erfc(UnaryScalarOp): ...@@ -49,4 +49,6 @@ class Erfc(UnaryScalarOp):
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return "%(z)s = erfc(%(x)s);" % locals() return "%(z)s = erfc(%(x)s);" % locals()
erfc = Erfc(upgrade_to_float, name = 'erfc')
# scipy.special.erfc don't support complex. Why?
erfc = Erfc(upgrade_to_float_no_complex, name = 'erfc')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论