提交 15cf3bc8 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed grad of Mul,Sgn,Ceil, and Floor

上级 18d6bd2f
......@@ -1184,26 +1184,25 @@ class Mul(ScalarOp):
' expected gz type to be complex, got gz with type '+\
str(gz.type))
if output_type in discrete_types:
return [ipt.zeros_like.astype(theano.config.floatX)
for ipt in inputs]
for input in inputs:
if input.type in continuous_types:
if gz.type in complex_types:
# zr+zi = (xr + xi)(yr + yi)
# zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr )
otherprod = mul(*(utils.difference(inputs, [input])))
yr = real(otherprod)
yi = imag(otherprod)
if input.type in complex_types:
retval += [complex(yr * real(gz) + yi * imag(gz),
yr * imag(gz) - yi * real(gz))]
else:
retval += [cast(yr * real(gz) + yi * imag(gz),
input.type.dtype)]
if gz.type in complex_types:
# zr+zi = (xr + xi)(yr + yi)
# zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr )
otherprod = mul(*(utils.difference(inputs, [input])))
yr = real(otherprod)
yi = imag(otherprod)
if input.type in complex_types:
retval += [complex(yr * real(gz) + yi * imag(gz),
yr * imag(gz) - yi * real(gz))]
else:
retval += [cast(mul(*([gz] + utils.difference(inputs,
[input]))),
input.type.dtype)]
retval += [yr * real(gz) + yi * imag(gz)]
else:
retval += [None]
retval += [mul(*([gz] + utils.difference(inputs,
[input])))]
return retval
......@@ -1697,7 +1696,13 @@ class Sgn(UnaryScalarOp):
return numpy.sign(x)
def grad(self, (x, ), (gz, )):
return None,
rval = x.zeros_like()
if rval.type.dtype in discrete_types:
rval = rval.astype(theano.config.floatX)
return [rval]
def c_code(self, node, name, (x, ), (z, ), sub):
#casting is done by compiler
......@@ -1723,7 +1728,12 @@ class Ceil(UnaryScalarOp):
return numpy.ceil(x)
def grad(self, (x,), (gz,)):
return None,
rval = x.zeros_like()
if rval.type.dtype in discrete_types:
rval = rval.astype(theano.config.floatX)
return [rval]
def c_code(self, node, name, (x,), (z,), sub):
return "%(z)s = ceil(%(x)s);" % locals()
......@@ -1735,7 +1745,12 @@ class Floor(UnaryScalarOp):
return numpy.floor(x)
def grad(self, (x,), (gz,)):
return None,
rval = x.zeros_like()
if rval.type.dtype in discrete_types:
rval = rval.astype(theano.config.floatX)
return [rval]
def c_code(self, node, name, (x,), (z,), sub):
return "%(z)s = floor(%(x)s);" % locals()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论