提交 963ea79a authored 作者: sentient07's avatar sentient07

minor changes

上级 62696871
......@@ -11,8 +11,7 @@ from theano.compile import optdb
from theano.compile.ops import shape_i
from theano.gof import (local_optimizer, EquilibriumDB, TopoOptimizer,
SequenceDB, Optimizer, toolbox)
from theano.gof.optdb import LocalGroupDB, Query
from theano.gof.op import Op
from theano.gof.optdb import LocalGroupDB
from theano.ifelse import IfElse
from theano.scalar.basic import Scalar, Pow, Cast
......@@ -218,7 +217,7 @@ gpu_seqopt.register('InputToGpuArrayOptimizer', InputToGpuOptimizer(),
class GraphToGPU(Optimizer):
"""
Transfer the graph as a whole to GPU instead of replacing nodes
Transfer the graph as a whole to GPU instead of transfering node by node.
"""
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
......@@ -279,13 +278,15 @@ class GraphToGPU(Optimizer):
for o in node.outputs:
mapping[o] = o
new_nodes = []
for o in fgraph.outputs:
new_o = mapping[o]
if new_o.type != o.type:
assert isinstance(o.type, tensor.TensorType)
assert isinstance(new_o.type, GpuArrayType)
new_o = host_from_gpu(new_o)
fgraph.replace_validate(o, new_o)
new_nodes.append(new_o)
fgraph.replace_all_validate(zip(fgraph.outputs, new_nodes))
gpu_seqopt.register('GraphToGPU', GraphToGPU(),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论