提交 34430def authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed grad method for Cast

上级 47f61579
...@@ -1604,10 +1604,10 @@ class Cast(UnaryScalarOp): ...@@ -1604,10 +1604,10 @@ class Cast(UnaryScalarOp):
return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x) return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if x.type in continuous_types and self.o_type in continuous_types: if self.o_type in continuous_types:
return [cast(gz, x.type.dtype)] return [ gz ]
else: else:
return None, return [ x.zeros_like().astype(theano.config.floatX) ]
def c_code_cache_version(self): def c_code_cache_version(self):
s = super(Cast, self).c_code_cache_version() s = super(Cast, self).c_code_cache_version()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论