提交 02a737df authored 作者: James Bergstra's avatar James Bergstra

constant folding uses type attribute "Constant"

上级 7ec1396c
......@@ -1183,16 +1183,14 @@ def constant_folding(node):
return False
storage = [[None] for output in node.outputs]
node.op.perform(node, [x.data for x in node.inputs], storage)
#TODO: think about how to extend to more types
msg = []
for s, output in zip(storage, node.outputs):
if isinstance(s[0], (N.ndarray,int,float)):
msg += [T.TensorConstant(output.type,s[0])]
else:
msg += [gof.Constant(output.type, s[0])]
try:
constant = output.type.Constant
except:
constant = gof.Constant
msg += [constant(output.type, s[0])]
return msg
#TODO: verify this backport!!
#return [(T.TensorConstant if isinstance(s[0], (N.ndarray,int,float)) else gof.Constant)(output.type, s[0]) for s, output in zip(storage, node.outputs)]
register_canonicalize(constant_folding)
register_specialize(constant_folding)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论