提交 6c63d25c authored 作者: carriepl's avatar carriepl

Use io_toposort() instead of toposort in PushOutSeqScan and PushOutNonSeqScan

上级 89afb152
......@@ -240,14 +240,11 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs,
clean_outputs,
clone=False)
local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs)
local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v, k) for k, v in \
enumerate(local_fgraph.outputs)])
enumerate(clean_outputs)])
to_remove_set = set()
to_replace_set = set()
......@@ -452,12 +449,11 @@ class PushOutSeqScan(gof.Optimizer):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs,
clone=False)
local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs)
local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v,k) for k,v in \
enumerate(local_fgraph.outputs)])
enumerate(clean_outputs)])
to_remove_set = set()
to_replace_set = set()
......@@ -640,7 +636,7 @@ class PushOutSeqScan(gof.Optimizer):
if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]]
_y = replace_with_out[idx]
ls = local_fgraph.outputs
ls = clean_outputs
if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node)[odx]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论