提交 62696871 authored 作者: sentient07's avatar sentient07

Fixed silly error

上级 846cfabd
...@@ -238,7 +238,7 @@ class GraphToGPU(Optimizer): ...@@ -238,7 +238,7 @@ class GraphToGPU(Optimizer):
# Iterating through output of all the nodes # Iterating through output of all the nodes
for n in fgraph.toposort(): for n in fgraph.toposort():
for o in n.outputs: for o in n.outputs:
if isinstance(i.type, tensor.TensorType): if isinstance(o.type, tensor.TensorType):
mapping[o] = GpuFromHost(None)(o) mapping[o] = GpuFromHost(None)(o)
else: else:
mapping[o] = o mapping[o] = o
...@@ -280,10 +280,7 @@ class GraphToGPU(Optimizer): ...@@ -280,10 +280,7 @@ class GraphToGPU(Optimizer):
mapping[o] = o mapping[o] = o
for o in fgraph.outputs: for o in fgraph.outputs:
try:
new_o = mapping[o] new_o = mapping[o]
except KeyError as k:
pass
if new_o.type != o.type: if new_o.type != o.type:
assert isinstance(o.type, tensor.TensorType) assert isinstance(o.type, tensor.TensorType)
assert isinstance(new_o.type, GpuArrayType) assert isinstance(new_o.type, GpuArrayType)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论