提交 599a4da4 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fix Theano to properly use the functions created for grabbing different

inputs.
上级 3f6b9b51
......@@ -655,7 +655,7 @@ class Scan(PureOp):
self.n_shared_outs)
return list_inputs[offset:]
def outer_non_seqs(self, list_inputs:
def outer_non_seqs(self, list_inputs):
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_nit_sot + self.n_shared_outs)
return list_inputs[offset:]
......
......@@ -317,12 +317,12 @@ def scan_make_inplace(node):
info['inplace'] = True
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node)
ls += op.outer_mitsot(node)
ls += op.outer_sitsot(node)
ls_end = op.outer_shared(node)
ls_end += op.outer_nitsot(node)
ls_end += op.outer_non_seqs(node)
ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node.inputs)
ls_end = op.outer_shared(node.inputs)
ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs)
n_outs = len(ls)
for idx in xrange(n_outs):
if ls[idx] in ls[:idx]:
......@@ -844,54 +844,54 @@ class ScanMerge(gof.Optimizer):
for idx, nd in enumerate(nodes):
# Seq
inner_ins += rename(nd.op.inner_seqs(), idx)
outer_ins += rename(nd.op.outer_seqs(nd), idx)
inner_ins += rename(nd.op.inner_seqs(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
for idx, nd in enumerate(nodes):
# MitMot
inner_ins += rename(nd.op.inner_mitmot(), idx)
inner_outs += nd.op.inner_mitmot_outs()
inner_ins += rename(nd.op.inner_mitmot(nd.op.inputs), idx)
inner_outs += nd.op.inner_mitmot_outs(nd.op.outputs)
info['tap_array'] += nd.op.mitmot_taps()
info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd), idx)
outer_outs += nd.op.outer_mitmot_outs(nd)
outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# MitSot
inner_ins += rename(nd.op.inner_mitsot(), idx)
inner_outs += nd.op.inner_mitsot_outs()
inner_ins += rename(nd.op.inner_mitsot(nd.op.inputs), idx)
inner_outs += nd.op.inner_mitsot_outs(nd.op.outputs)
info['tap_array'] += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd), idx)
outer_outs += nd.op.outer_mitsot_outs(nd)
outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# SitSot
inner_ins += rename(nd.op.inner_sitsot(), idx)
inner_ins += rename(nd.op.inner_sitsot(nd.op.inputs), idx)
info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
inner_outs += nd.op.inner_sitsot_outs()
outer_ins += rename(nd.op.outer_sitsot(nd), idx)
outer_outs += nd.op.outer_sitsot_outs(nd)
inner_outs += nd.op.inner_sitsot_outs(nd.op.outputs)
outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# Shared
inner_ins += rename(nd.op.inner_shared(), idx)
outer_ins += rename(nd.op.outer_shared(nd), idx)
inner_ins += rename(nd.op.inner_shared(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_shared(nd.inputs), idx)
for idx, nd in enumerate(nodes):
# NitSot
inner_outs += nd.op.inner_nitsot_outs()
outer_ins += rename(nd.op.outer_nitsot(nd), idx)
outer_outs += nd.op.outer_nitsot_outs(nd)
inner_outs += nd.op.inner_nitsot_outs(nd.op.outputs)
outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx)
outer_outs += nd.op.outer_nitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# Shared
outer_outs += nd.op.outer_shared_outs(nd)
inner_outs += nd.op.inner_shared_outs()
outer_outs += nd.op.outer_shared_outs(nd.outputs)
inner_outs += nd.op.inner_shared_outs(nd.op.outputs)
for idx, nd in enumerate(nodes):
# Non Seqs
inner_ins += rename(nd.op.inner_non_seqs(), idx)
outer_ins += rename(nd.op.outer_non_seqs(nd), idx)
inner_ins += rename(nd.op.inner_non_seqs(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)
# Add back the number of steps
outer_ins = [nodes[0].inputs[0]] + outer_ins
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论