提交 64b59c00 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

Additional corrections in scalar/basic.py

上级 94bdc43c
...@@ -420,12 +420,12 @@ class Scalar(Type): ...@@ -420,12 +420,12 @@ class Scalar(Type):
{ this->real=y.real; this->imag=y.imag; return *this; } { this->real=y.real; this->imag=y.imag; return *this; }
''' % dict(mytype=mytype, othertype=othertype) ''' % dict(mytype=mytype, othertype=othertype)
operator_eq = ''.join(operator_eq_real(ctype, rtype) operator_eq = (''.join(operator_eq_real(ctype, rtype)
for ctype in cplx_types for ctype in cplx_types
for rtype in real_types) \ for rtype in real_types) +
+ ''.join(operator_eq_cplx(ctype1, ctype2) ''.join(operator_eq_cplx(ctype1, ctype2)
for ctype1 in cplx_types for ctype1 in cplx_types
for ctype2 in cplx_types) for ctype2 in cplx_types))
# We are not using C++ generic templating here, because this would # We are not using C++ generic templating here, because this would
# generate two different functions for adding a complex64 and a # generate two different functions for adding a complex64 and a
...@@ -472,12 +472,12 @@ class Scalar(Type): ...@@ -472,12 +472,12 @@ class Scalar(Type):
for ctype in cplx_types for ctype in cplx_types
for rtype in real_types) for rtype in real_types)
return template % dict(nbits=64, half_nbits=32) \ return (template % dict(nbits=64, half_nbits=32) +
+ template % dict(nbits=128, half_nbits=64) \ template % dict(nbits=128, half_nbits=64) +
+ operator_eq \ operator_eq +
+ operator_plus \ operator_plus +
+ operator_minus \ operator_minus +
+ operator_mul operator_mul)
else: else:
return "" return ""
...@@ -890,9 +890,9 @@ class ScalarOp(Op): ...@@ -890,9 +890,9 @@ class ScalarOp(Op):
self.__class__.__name__) self.__class__.__name__)
def __eq__(self, other): def __eq__(self, other):
test = type(self) == type(other) \ test = (type(self) == type(other) and
and getattr(self, 'output_types_preference', None) \ getattr(self, 'output_types_preference', None) ==
== getattr(other, 'output_types_preference', None) getattr(other, 'output_types_preference', None))
return test return test
def __hash__(self): def __hash__(self):
...@@ -1189,8 +1189,8 @@ class InRange(LogicalComparison): ...@@ -1189,8 +1189,8 @@ class InRange(LogicalComparison):
def get_grad(self, elem): def get_grad(self, elem):
if elem.type in complex_types: if elem.type in complex_types:
msg = "No gradient implemented for complex numbers in\ msg = ("No gradient implemented for complex numbers in "
class scalar.basic.InRange" "class scalar.basic.InRange")
raise NotImplementedError(msg) raise NotImplementedError(msg)
elif elem.type in discrete_types: elif elem.type in discrete_types:
return elem.zeros_like().astype(theano.config.floatX) return elem.zeros_like().astype(theano.config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论