提交 15cf3bc8 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed grad of Mul,Sgn,Ceil, and Floor

上级 18d6bd2f
...@@ -1184,26 +1184,25 @@ class Mul(ScalarOp): ...@@ -1184,26 +1184,25 @@ class Mul(ScalarOp):
' expected gz type to be complex, got gz with type '+\ ' expected gz type to be complex, got gz with type '+\
str(gz.type)) str(gz.type))
if output_type in discrete_types:
return [ipt.zeros_like.astype(theano.config.floatX)
for ipt in inputs]
for input in inputs: for input in inputs:
if input.type in continuous_types: if gz.type in complex_types:
if gz.type in complex_types: # zr+zi = (xr + xi)(yr + yi)
# zr+zi = (xr + xi)(yr + yi) # zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr )
# zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr ) otherprod = mul(*(utils.difference(inputs, [input])))
otherprod = mul(*(utils.difference(inputs, [input]))) yr = real(otherprod)
yr = real(otherprod) yi = imag(otherprod)
yi = imag(otherprod) if input.type in complex_types:
if input.type in complex_types: retval += [complex(yr * real(gz) + yi * imag(gz),
retval += [complex(yr * real(gz) + yi * imag(gz), yr * imag(gz) - yi * real(gz))]
yr * imag(gz) - yi * real(gz))]
else:
retval += [cast(yr * real(gz) + yi * imag(gz),
input.type.dtype)]
else: else:
retval += [cast(mul(*([gz] + utils.difference(inputs, retval += [yr * real(gz) + yi * imag(gz)]
[input]))),
input.type.dtype)]
else: else:
retval += [None] retval += [mul(*([gz] + utils.difference(inputs,
[input])))]
return retval return retval
...@@ -1697,7 +1696,13 @@ class Sgn(UnaryScalarOp): ...@@ -1697,7 +1696,13 @@ class Sgn(UnaryScalarOp):
return numpy.sign(x) return numpy.sign(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return None,
rval = x.zeros_like()
if rval.type.dtype in discrete_types:
rval = rval.astype(theano.config.floatX)
return [rval]
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
#casting is done by compiler #casting is done by compiler
...@@ -1723,7 +1728,12 @@ class Ceil(UnaryScalarOp): ...@@ -1723,7 +1728,12 @@ class Ceil(UnaryScalarOp):
return numpy.ceil(x) return numpy.ceil(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
return None, rval = x.zeros_like()
if rval.type.dtype in discrete_types:
rval = rval.astype(theano.config.floatX)
return [rval]
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
return "%(z)s = ceil(%(x)s);" % locals() return "%(z)s = ceil(%(x)s);" % locals()
...@@ -1735,7 +1745,12 @@ class Floor(UnaryScalarOp): ...@@ -1735,7 +1745,12 @@ class Floor(UnaryScalarOp):
return numpy.floor(x) return numpy.floor(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
return None, rval = x.zeros_like()
if rval.type.dtype in discrete_types:
rval = rval.astype(theano.config.floatX)
return [rval]
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
return "%(z)s = floor(%(x)s);" % locals() return "%(z)s = floor(%(x)s);" % locals()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论