提交 a2599226，作者：Pierre Luc Carrier

Fix failing tests

上级 1a72f433
...@@ -612,6 +612,11 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -612,6 +612,11 @@ class PushOutScanOutput(gof.Optimizer):
op = node.op op = node.op
# Obtain the list containing the indices, in clean_outputs, of the
# scan op's outputs that are nit_sot (not fed back to the inner fct.)
nitsot_outs = op.inner_nitsot_outs(node.outputs)
idx_nitsot_outs = [node.outputs.index(i) for i in nitsot_outs]
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
...@@ -621,6 +626,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -621,6 +626,7 @@ class PushOutScanOutput(gof.Optimizer):
assert len(inner_seqs) == len(outer_seqs) assert len(inner_seqs) == len(outer_seqs)
new_scan_node = None new_scan_node = None
for nd in local_fgraph.toposort(): for nd in local_fgraph.toposort():
if (isinstance(nd.op, theano.tensor.Dot) and if (isinstance(nd.op, theano.tensor.Dot) and
...@@ -637,12 +643,29 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -637,12 +643,29 @@ class PushOutScanOutput(gof.Optimizer):
concatenating the vectors into a matrix. concatenating the vectors into a matrix.
""" """
# Ensure that the output of the Dot is used somewhere # Go through clean_outputs and pick one that is
# in the outer graph # - Equal to the output of the tensor.Dot
idx_dot_output = clean_outputs.index(nd.out) # - Nit_sot : not fed back to the inner graph because applying
if len(node.outputs[idx_dot_output].clients) == 0: # the optimization in that case would alter the results of
# The Dot's output is not used. It is not worth performing # the function
# the optimization. Move on to the next node # - Used by something outside of the graph to avoid applying
# the optimization needlessly
idx_dot_output = -1
for i in range(len(clean_outputs)):
is_dot_output = (nd.out == clean_outputs[i])
is_nitsot_output = i in idx_nitsot_outs
used_in_outer_graph = (len(node.outputs[i].clients) > 0)
if (is_dot_output and is_nitsot_output and
used_in_outer_graph):
idx_dot_output = i
break
if idx_dot_output == -1:
# The dot has no output that fits the requirements for
# this optimization. Move on to the next node.
continue continue
""" """
...@@ -710,7 +733,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -710,7 +733,8 @@ class PushOutScanOutput(gof.Optimizer):
# Perform the Dot on the new scan output. # Perform the Dot on the new scan output.
if idx_matrix_input == 0: if idx_matrix_input == 0:
outer_dot_inputs = [outer_matrix_input, outer_dot_inputs = [outer_matrix_input,
new_outer_output] new_outer_output.transpose()]
outer_dot_output = theano.tensor.dot(*outer_dot_inputs).transpose()
else: # idx_matrix_input == 1 else: # idx_matrix_input == 1
outer_dot_inputs = [new_outer_output, outer_dot_inputs = [new_outer_output,
outer_matrix_input] outer_matrix_input]
...@@ -736,8 +760,8 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -736,8 +760,8 @@ class PushOutScanOutput(gof.Optimizer):
""" """
# Compute the index at which to insert the new output. For a scan Op, # Compute the index at which to insert the new output. For a scan Op,
# the outputs the ordering : mit_mot, mit_sot, sis_sot, nit_sot and # the outputs follow the ordering : mit_mot, mit_sot, sis_sot, nit_sot
# shared_outs # and shared_outs
output_insert_idx = (scan_node.op.info['n_mit_mot'] + output_insert_idx = (scan_node.op.info['n_mit_mot'] +
scan_node.op.info['n_mit_sot'] + scan_node.op.info['n_mit_sot'] +
scan_node.op.info['n_sit_sot'] + scan_node.op.info['n_sit_sot'] +
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请先 注册 或者 登录 后发表评论