提交 62696871 authored 作者: sentient07's avatar sentient07

Fixed silly error

上级 846cfabd
...@@ -238,7 +238,7 @@ class GraphToGPU(Optimizer): ...@@ -238,7 +238,7 @@ class GraphToGPU(Optimizer):
# Iterating through output of all the nodes # Iterating through output of all the nodes
for n in fgraph.toposort(): for n in fgraph.toposort():
for o in n.outputs: for o in n.outputs:
if isinstance(i.type, tensor.TensorType): if isinstance(o.type, tensor.TensorType):
mapping[o] = GpuFromHost(None)(o) mapping[o] = GpuFromHost(None)(o)
else: else:
mapping[o] = o mapping[o] = o
...@@ -279,16 +279,13 @@ class GraphToGPU(Optimizer): ...@@ -279,16 +279,13 @@ class GraphToGPU(Optimizer):
for o in node.outputs: for o in node.outputs:
mapping[o] = o mapping[o] = o
for o in fgraph.outputs: for o in fgraph.outputs:
try: new_o = mapping[o]
new_o = mapping[o] if new_o.type != o.type:
except KeyError as k: assert isinstance(o.type, tensor.TensorType)
pass assert isinstance(new_o.type, GpuArrayType)
if new_o.type != o.type: new_o = host_from_gpu(new_o)
assert isinstance(o.type, tensor.TensorType) fgraph.replace_validate(o, new_o)
assert isinstance(new_o.type, GpuArrayType)
new_o = host_from_gpu(new_o)
fgraph.replace_validate(o, new_o)
gpu_seqopt.register('GraphToGPU', GraphToGPU(), gpu_seqopt.register('GraphToGPU', GraphToGPU(),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论