提交 fe500996 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Don't allow operations that are deprecated in numpy.

上级 c6b619c9
...@@ -690,6 +690,13 @@ def upcast_out(*types): ...@@ -690,6 +690,13 @@ def upcast_out(*types):
def upcast_out_nobool(*types): def upcast_out_nobool(*types):
type = upcast_out(*types)
if type[0] == bool:
raise TypeError("bool output not supported")
return type
def upcast_out_min8(*types):
type = upcast_out(*types) type = upcast_out(*types)
if type[0] == bool: if type[0] == bool:
return int8, return int8,
...@@ -719,6 +726,12 @@ def same_out(type): ...@@ -719,6 +726,12 @@ def same_out(type):
def same_out_nobool(type): def same_out_nobool(type):
if type == bool:
raise TypeError("bool input not supported")
return type,
def same_out_min8(type):
if type == bool: if type == bool:
return int8, return int8,
return type, return type,
...@@ -1585,9 +1598,6 @@ class Sub(BinaryScalarOp): ...@@ -1585,9 +1598,6 @@ class Sub(BinaryScalarOp):
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x, y) = inputs (x, y) = inputs
(z,) = outputs (z,) = outputs
if node.outputs[0].type == bool:
# xor
return "%(z)s = (%(x)s || %(y)s) && !(%(x)s && %(y)s);" % locals()
return "%(z)s = %(x)s - %(y)s;" % locals() return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, inputs, gout): def grad(self, inputs, gout):
...@@ -1604,7 +1614,7 @@ class Sub(BinaryScalarOp): ...@@ -1604,7 +1614,7 @@ class Sub(BinaryScalarOp):
second_part = -gz second_part = -gz
return first_part, second_part return first_part, second_part
sub = Sub(upcast_out, name='sub') sub = Sub(upcast_out_nobool, name='sub')
def int_or_true_div(x_discrete, y_discrete): def int_or_true_div(x_discrete, y_discrete):
...@@ -1970,7 +1980,7 @@ class Pow(BinaryScalarOp): ...@@ -1970,7 +1980,7 @@ class Pow(BinaryScalarOp):
raise theano.gof.utils.MethodNotDefined() raise theano.gof.utils.MethodNotDefined()
pow = Pow(upcast_out_nobool, name='pow') pow = Pow(upcast_out_min8, name='pow')
class Clip(ScalarOp): class Clip(ScalarOp):
...@@ -2468,10 +2478,8 @@ class Neg(UnaryScalarOp): ...@@ -2468,10 +2478,8 @@ class Neg(UnaryScalarOp):
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
(z,) = outputs (z,) = outputs
if node.outputs[0].type == bool:
return "%(z)s = !%(x)s;" % locals()
return "%(z)s = -%(x)s;" % locals() return "%(z)s = -%(x)s;" % locals()
neg = Neg(same_out, name='neg') neg = Neg(same_out_nobool, name='neg')
pprint.assign(add, printing.OperatorPrinter('+', -2, 'either')) pprint.assign(add, printing.OperatorPrinter('+', -2, 'either'))
pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either')) pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either'))
...@@ -3461,7 +3469,7 @@ class Conj(UnaryScalarOp): ...@@ -3461,7 +3469,7 @@ class Conj(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return numpy.conj(x) return numpy.conj(x)
conj = Conj(same_out_nobool, name='conj') conj = Conj(same_out_min8, name='conj')
class ComplexFromPolar(BinaryScalarOp): class ComplexFromPolar(BinaryScalarOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论