提交 599a4da4 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fix Theano to properly use the functions created for grabbing different

inputs.
上级 3f6b9b51
...@@ -655,7 +655,7 @@ class Scan(PureOp): ...@@ -655,7 +655,7 @@ class Scan(PureOp):
self.n_shared_outs) self.n_shared_outs)
return list_inputs[offset:] return list_inputs[offset:]
def outer_non_seqs(self, list_inputs: def outer_non_seqs(self, list_inputs):
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_nit_sot + self.n_shared_outs) self.n_sit_sot + self.n_nit_sot + self.n_shared_outs)
return list_inputs[offset:] return list_inputs[offset:]
......
...@@ -317,12 +317,12 @@ def scan_make_inplace(node): ...@@ -317,12 +317,12 @@ def scan_make_inplace(node):
info['inplace'] = True info['inplace'] = True
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs] ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node) ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node) ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node) ls += op.outer_sitsot(node.inputs)
ls_end = op.outer_shared(node) ls_end = op.outer_shared(node.inputs)
ls_end += op.outer_nitsot(node) ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node) ls_end += op.outer_non_seqs(node.inputs)
n_outs = len(ls) n_outs = len(ls)
for idx in xrange(n_outs): for idx in xrange(n_outs):
if ls[idx] in ls[:idx]: if ls[idx] in ls[:idx]:
...@@ -844,54 +844,54 @@ class ScanMerge(gof.Optimizer): ...@@ -844,54 +844,54 @@ class ScanMerge(gof.Optimizer):
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Seq # Seq
inner_ins += rename(nd.op.inner_seqs(), idx) inner_ins += rename(nd.op.inner_seqs(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_seqs(nd), idx) outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitMot # MitMot
inner_ins += rename(nd.op.inner_mitmot(), idx) inner_ins += rename(nd.op.inner_mitmot(nd.op.inputs), idx)
inner_outs += nd.op.inner_mitmot_outs() inner_outs += nd.op.inner_mitmot_outs(nd.op.outputs)
info['tap_array'] += nd.op.mitmot_taps() info['tap_array'] += nd.op.mitmot_taps()
info['mit_mot_out_slices'] += nd.op.mitmot_out_taps() info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd), idx) outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
outer_outs += nd.op.outer_mitmot_outs(nd) outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitSot # MitSot
inner_ins += rename(nd.op.inner_mitsot(), idx) inner_ins += rename(nd.op.inner_mitsot(nd.op.inputs), idx)
inner_outs += nd.op.inner_mitsot_outs() inner_outs += nd.op.inner_mitsot_outs(nd.op.outputs)
info['tap_array'] += nd.op.mitsot_taps() info['tap_array'] += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd), idx) outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd) outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# SitSot # SitSot
inner_ins += rename(nd.op.inner_sitsot(), idx) inner_ins += rename(nd.op.inner_sitsot(nd.op.inputs), idx)
info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)] info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
inner_outs += nd.op.inner_sitsot_outs() inner_outs += nd.op.inner_sitsot_outs(nd.op.outputs)
outer_ins += rename(nd.op.outer_sitsot(nd), idx) outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd) outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
inner_ins += rename(nd.op.inner_shared(), idx) inner_ins += rename(nd.op.inner_shared(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_shared(nd), idx) outer_ins += rename(nd.op.outer_shared(nd.inputs), idx)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# NitSot # NitSot
inner_outs += nd.op.inner_nitsot_outs() inner_outs += nd.op.inner_nitsot_outs(nd.op.outputs)
outer_ins += rename(nd.op.outer_nitsot(nd), idx) outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx)
outer_outs += nd.op.outer_nitsot_outs(nd) outer_outs += nd.op.outer_nitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
outer_outs += nd.op.outer_shared_outs(nd) outer_outs += nd.op.outer_shared_outs(nd.outputs)
inner_outs += nd.op.inner_shared_outs() inner_outs += nd.op.inner_shared_outs(nd.op.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Non Seqs # Non Seqs
inner_ins += rename(nd.op.inner_non_seqs(), idx) inner_ins += rename(nd.op.inner_non_seqs(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_non_seqs(nd), idx) outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)
# Add back the number of steps # Add back the number of steps
outer_ins = [nodes[0].inputs[0]] + outer_ins outer_ins = [nodes[0].inputs[0]] + outer_ins
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论