提交 62696871 authored 作者: sentient07's avatar sentient07

Fixed silly error

上级 846cfabd
......@@ -238,7 +238,7 @@ class GraphToGPU(Optimizer):
# Iterating through output of all the nodes
for n in fgraph.toposort():
for o in n.outputs:
if isinstance(i.type, tensor.TensorType):
if isinstance(o.type, tensor.TensorType):
mapping[o] = GpuFromHost(None)(o)
else:
mapping[o] = o
......@@ -279,16 +279,13 @@ class GraphToGPU(Optimizer):
for o in node.outputs:
mapping[o] = o
for o in fgraph.outputs:
try:
new_o = mapping[o]
except KeyError as k:
pass
if new_o.type != o.type:
assert isinstance(o.type, tensor.TensorType)
assert isinstance(new_o.type, GpuArrayType)
new_o = host_from_gpu(new_o)
fgraph.replace_validate(o, new_o)
for o in fgraph.outputs:
new_o = mapping[o]
if new_o.type != o.type:
assert isinstance(o.type, tensor.TensorType)
assert isinstance(new_o.type, GpuArrayType)
new_o = host_from_gpu(new_o)
fgraph.replace_validate(o, new_o)
gpu_seqopt.register('GraphToGPU', GraphToGPU(),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论