提交 963ea79a authored 作者: sentient07's avatar sentient07

minor changes

上级 62696871
...@@ -11,8 +11,7 @@ from theano.compile import optdb ...@@ -11,8 +11,7 @@ from theano.compile import optdb
from theano.compile.ops import shape_i from theano.compile.ops import shape_i
from theano.gof import (local_optimizer, EquilibriumDB, TopoOptimizer, from theano.gof import (local_optimizer, EquilibriumDB, TopoOptimizer,
SequenceDB, Optimizer, toolbox) SequenceDB, Optimizer, toolbox)
from theano.gof.optdb import LocalGroupDB, Query from theano.gof.optdb import LocalGroupDB
from theano.gof.op import Op
from theano.ifelse import IfElse from theano.ifelse import IfElse
from theano.scalar.basic import Scalar, Pow, Cast from theano.scalar.basic import Scalar, Pow, Cast
...@@ -218,7 +217,7 @@ gpu_seqopt.register('InputToGpuArrayOptimizer', InputToGpuOptimizer(), ...@@ -218,7 +217,7 @@ gpu_seqopt.register('InputToGpuArrayOptimizer', InputToGpuOptimizer(),
class GraphToGPU(Optimizer): class GraphToGPU(Optimizer):
""" """
Transfer the graph as a whole to GPU instead of replacing nodes Transfer the graph as a whole to GPU instead of transferring node by node.
""" """
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(toolbox.ReplaceValidate())
...@@ -279,13 +278,15 @@ class GraphToGPU(Optimizer): ...@@ -279,13 +278,15 @@ class GraphToGPU(Optimizer):
for o in node.outputs: for o in node.outputs:
mapping[o] = o mapping[o] = o
new_nodes = []
for o in fgraph.outputs: for o in fgraph.outputs:
new_o = mapping[o] new_o = mapping[o]
if new_o.type != o.type: if new_o.type != o.type:
assert isinstance(o.type, tensor.TensorType) assert isinstance(o.type, tensor.TensorType)
assert isinstance(new_o.type, GpuArrayType) assert isinstance(new_o.type, GpuArrayType)
new_o = host_from_gpu(new_o) new_o = host_from_gpu(new_o)
fgraph.replace_validate(o, new_o) new_nodes.append(new_o)
fgraph.replace_all_validate(zip(fgraph.outputs, new_nodes))
gpu_seqopt.register('GraphToGPU', GraphToGPU(), gpu_seqopt.register('GraphToGPU', GraphToGPU(),
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论