提交 a66acf6a authored 作者: James Bergstra's avatar James Bergstra

fixed hack in Cast c code in scalar basic

上级 a22938cd
......@@ -880,11 +880,7 @@ class Cast(UnaryScalarOp):
def impl(self, input):
return self.ctor(input)
def c_code(self, node, name, (x, ), (z, ), sub):
#HACK: we assume that x has the form 'VARNAME_i',
# and we need the varname to get the dtype.
assert (len(x) > 2) and (x[-2:] == '_i')
varname = x[:-2]
return "%s = (dtype_%s)%s;" % (z, varname, x)
return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x)
def grad(self, (x, ), (gz, )):
if x.type in grad_types:
return [cast(gz, x.type.dtype)]
......@@ -893,7 +889,7 @@ class Cast(UnaryScalarOp):
def c_code_cache_version(self):
s = super(Cast, self).c_code_cache_version()
if s:
return (2,) + s
return (3,) + s
else:
return s
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论