提交 6c63d25c authored 作者: carriepl's avatar carriepl

Use io_toposort() instead of toposort in PushOutSeqScan and PushOutNonSeqScan

上级 89afb152
...@@ -240,14 +240,11 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -240,14 +240,11 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph( clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs) node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs, clean_outputs)
clone=False) local_fgraph_outs_set = set(clean_outputs)
local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_outs_map = dict([(v, k) for k, v in \ local_fgraph_outs_map = dict([(v, k) for k, v in \
enumerate(local_fgraph.outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
...@@ -452,12 +449,11 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -452,12 +449,11 @@ class PushOutSeqScan(gof.Optimizer):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph( clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs) node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clone=False) clean_outputs)
local_fgraph_topo = local_fgraph.toposort() local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_outs_map = dict([(v,k) for k,v in \ local_fgraph_outs_map = dict([(v,k) for k,v in \
enumerate(local_fgraph.outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
...@@ -640,7 +636,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -640,7 +636,7 @@ class PushOutSeqScan(gof.Optimizer):
if out in local_fgraph_outs_set: if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]] x = node.outputs[local_fgraph_outs_map[out]]
_y = replace_with_out[idx] _y = replace_with_out[idx]
ls = local_fgraph.outputs ls = clean_outputs
if out in op.inner_mitsot_outs(ls): if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out) odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node)[odx] inp = op.outer_mitsot(node)[odx]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论