提交 e991b90b authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3709 from nouiz/crash_scan

Fix BUG IN FGRAPH.REPLACE OR A LISTENER
...@@ -1607,28 +1607,32 @@ class Scan(PureOp): ...@@ -1607,28 +1607,32 @@ class Scan(PureOp):
inner_ins_shapes = [] inner_ins_shapes = []
out_equivalent = OrderedDict() out_equivalent = OrderedDict()
# The two following blocks are commented as it cause in some
# cases extra scans in the graph. See gh-XXX for the
# investigation.
# We skip the first outer input as it is the total or current number # We skip the first outer input as it is the total or current number
# of iterations. # of iterations.
# sequences # sequences
seqs_shape = [x[1:] for x in input_shapes[1:1 + self.n_seqs]] seqs_shape = [x[1:] for x in input_shapes[1:1 + self.n_seqs]]
inner_seqs = self.inputs[:self.n_seqs] # inner_seqs = self.inputs[:self.n_seqs]
outer_seqs = node.inputs[1:1 + self.n_seqs] # outer_seqs = node.inputs[1:1 + self.n_seqs]
for in_s, out_s in izip(inner_seqs, outer_seqs): # for in_s, out_s in izip(inner_seqs, outer_seqs):
out_equivalent[in_s] = out_s[0] # out_equivalent[in_s] = out_s[0]
# mit_mot, mit_sot, sit_sot # mit_mot, mit_sot, sit_sot
outer_inp_idx = 1 + self.n_seqs # outer_inp_idx = 1 + self.n_seqs
inner_inp_idx = self.n_seqs # inner_inp_idx = self.n_seqs
n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outs_shape = [] outs_shape = []
for idx in xrange(n_outs): for idx in xrange(n_outs):
mintap = abs(min(self.tap_array[idx])) # mintap = abs(min(self.tap_array[idx]))
for k in self.tap_array[idx]: for k in self.tap_array[idx]:
outs_shape += [input_shapes[idx + self.n_seqs + 1][1:]] outs_shape += [input_shapes[idx + self.n_seqs + 1][1:]]
corresponding_tap = node.inputs[outer_inp_idx][mintap + k] # corresponding_tap = node.inputs[outer_inp_idx][mintap + k]
out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap # out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap
inner_inp_idx += 1 # inner_inp_idx += 1
outer_inp_idx += 1 # outer_inp_idx += 1
# shared_outs # shared_outs
offset = 1 + self.n_seqs + n_outs offset = 1 + self.n_seqs + n_outs
......
...@@ -693,12 +693,9 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -693,12 +693,9 @@ class PushOutScanOutput(gof.Optimizer):
args = scan_args(node.inputs, node.outputs, args = scan_args(node.inputs, node.outputs,
op.inputs, op.outputs, op.info) op.inputs, op.outputs, op.info)
local_fgraph = gof.FunctionGraph(args.inner_inputs,
args.inner_outputs,
clone=False)
new_scan_node = None new_scan_node = None
local_fgraph_topo = local_fgraph.toposort() local_fgraph_topo = theano.gof.graph.io_toposort(op.inputs, op.outputs)
for nd in local_fgraph_topo: for nd in local_fgraph_topo:
if (isinstance(nd.op, theano.tensor.Dot) and if (isinstance(nd.op, theano.tensor.Dot) and
nd.out in args.inner_out_nit_sot): nd.out in args.inner_out_nit_sot):
...@@ -860,7 +857,6 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -860,7 +857,6 @@ class PushOutScanOutput(gof.Optimizer):
reason="scanOp_pushout_output") reason="scanOp_pushout_output")
break break
return new_scan_node return new_scan_node
def inner_sitsot_only_last_step_used(self, var, scan_args): def inner_sitsot_only_last_step_used(self, var, scan_args):
......
...@@ -2567,25 +2567,27 @@ class T_Scan(unittest.TestCase): ...@@ -2567,25 +2567,27 @@ class T_Scan(unittest.TestCase):
diff = mitsot_m1 + seq1 diff = mitsot_m1 + seq1
next_mitsot_val = mitsot_m2 + diff next_mitsot_val = mitsot_m2 + diff
next_sitsot_val = sitsot_m1 - diff next_sitsot_val = sitsot_m1 - diff
nitsot_out = tensor.AllocEmpty('float32')(next_mitsot_val + nitsot_out = tensor.alloc(numpy.asarray(0., 'float32'),
next_sitsot_val) next_mitsot_val +
next_sitsot_val)
return next_sitsot_val, next_mitsot_val, nitsot_out return next_sitsot_val, next_mitsot_val, nitsot_out
out, updates = theano.scan(fn=step, out, updates = theano.scan(fn=step,
sequences=seq, sequences=seq,
outputs_info=[sitsot_init, outputs_info=[sitsot_init,
{'initial' : mitsot_init, {'initial': mitsot_init,
'taps' : [-2, -1]}, 'taps': [-2, -1]},
None], None],
n_steps=5) n_steps=5)
f = theano.function([seq, sitsot_init, mitsot_init], out[2].shape, f = theano.function([seq, sitsot_init, mitsot_init], out[2].shape,
mode='FAST_RUN') mode='FAST_RUN')
assert(len(scan_nodes_from_fct(f)) == 0) # When Scan.infer_shape will cover more case, there will no scan left.
assert(len(scan_nodes_from_fct(f)) == 1)
output_shape = f(numpy.arange(5), 5, [1, 2]) # This generate a scan crash during execution.
assert(all(output_shape == (5,6))) # output_shape = f(numpy.arange(5), 5, [1, 2])
# assert(all(output_shape == (5, 6)))
# The following test will fail in DebugMode if there are # The following test will fail in DebugMode if there are
# some problems in Scan.infer_shape # some problems in Scan.infer_shape
......
...@@ -212,7 +212,9 @@ class TestPushOutScanOutputDot(object): ...@@ -212,7 +212,9 @@ class TestPushOutScanOutputDot(object):
# not be the result of a Dot # not be the result of a Dot
scan_node = [node for node in f_opt.maker.fgraph.toposort() scan_node = [node for node in f_opt.maker.fgraph.toposort()
if isinstance(node.op, Scan)][0] if isinstance(node.op, Scan)][0]
assert len(scan_node.op.outputs) == 1 # NOTE: WHEN INFER_SHAPE IS REENABLED, BELLOW THE SCAN MUST
# HAVE ONLY 1 OUTPUT.
assert len(scan_node.op.outputs) == 2
assert not isinstance(scan_node.op.outputs[0], T.Dot) assert not isinstance(scan_node.op.outputs[0], T.Dot)
# Ensure that the function compiled with the optimization produces # Ensure that the function compiled with the optimization produces
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论