提交 0797283e authored 作者: sentient07's avatar sentient07

Final Cleanup

上级 6548ed04
...@@ -314,9 +314,9 @@ class GraphToGPU(NavigatorOptimizer): ...@@ -314,9 +314,9 @@ class GraphToGPU(NavigatorOptimizer):
for lopt in (self.local_optimizers_map.get(node.op, []) + for lopt in (self.local_optimizers_map.get(node.op, []) +
self.local_optimizers_map.get(type(node.op), []) + self.local_optimizers_map.get(type(node.op), []) +
self.local_optimizers_all): self.local_optimizers_all):
process_count.setdefault(lopt, 0) process_count.setdefault(lopt, 0)
time_opts.setdefault(lopt, 0) time_opts.setdefault(lopt, 0)
node_created.setdefault(lopt, 0) node_created.setdefault(lopt, 0)
for node in topo: for node in topo:
...@@ -326,12 +326,7 @@ class GraphToGPU(NavigatorOptimizer): ...@@ -326,12 +326,7 @@ class GraphToGPU(NavigatorOptimizer):
# Move only if any of the inputs are on the GPU. # Move only if any of the inputs are on the GPU.
move_to_GPU = False move_to_GPU = False
if any([isinstance(i, GpuArrayVariable) or
isinstance(i, GpuArraySharedVariable)
for i in [mapping[v] for v in node.inputs] +
node.outputs]):
move_to_GPU = True
context_name = None context_name = None
for i in [mapping[i] for i in node.inputs]: for i in [mapping[i] for i in node.inputs]:
if isinstance(i.type, GpuArrayType): if isinstance(i.type, GpuArrayType):
...@@ -754,7 +749,7 @@ def local_gpua_dimshuffle(op, context_name, inputs, outputs): ...@@ -754,7 +749,7 @@ def local_gpua_dimshuffle(op, context_name, inputs, outputs):
@register_opt('fast_compile') @register_opt('fast_compile')
@op_lifter([tensor.SpecifyShape]) @op_lifter([tensor.SpecifyShape])
# @register_opt2([tensor.SpecifyShape], 'fast_compile') @register_opt2([tensor.SpecifyShape], 'fast_compile')
def local_gpua_specifyShape(op, context_name, inputs, outputs): def local_gpua_specifyShape(op, context_name, inputs, outputs):
if isinstance(inputs[0].type, GpuArrayType): if isinstance(inputs[0].type, GpuArrayType):
return return
...@@ -1198,7 +1193,7 @@ def local_gpua_dot22scalar(op, context_name, inputs, outputs): ...@@ -1198,7 +1193,7 @@ def local_gpua_dot22scalar(op, context_name, inputs, outputs):
@op_lifter([tensor.basic.Eye]) @op_lifter([tensor.basic.Eye])
@register_opt2([tensor.basic.Eye], 'fast_compile') @register_opt2([tensor.basic.Eye], 'fast_compile')
def local_gpua_eye(op, context_name, inputs, outputs): def local_gpua_eye(op, context_name, inputs, outputs):
return [GpuEye(dtype=op.dtype, context_name=context_name)(*inputs)] return GpuEye(dtype=op.dtype, context_name=context_name)(*inputs)
@register_opt('fast_compile') @register_opt('fast_compile')
...@@ -1298,7 +1293,7 @@ def local_lift_abstractconv2d(op, context_name, inputs, outputs): ...@@ -1298,7 +1293,7 @@ def local_lift_abstractconv2d(op, context_name, inputs, outputs):
if isinstance(outputs[0].type, GpuArrayType): if isinstance(outputs[0].type, GpuArrayType):
# Don't handle this node here, it's already on the GPU. # Don't handle this node here, it's already on the GPU.
return return
local_lift_abstractconv2d_graph(op, context_name, inputs, outputs) return local_lift_abstractconv2d_graph(op, context_name, inputs, outputs)
@register_opt2([AbstractConv2d, @register_opt2([AbstractConv2d,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论