提交 c8fc0905 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

made the code to use generic input/output list

Before I used to have the different types of inputs/outputs seperated into lists inside the scan op. I've made a more generic approach, where there is a simple list of all inputs/outputs for scan.
上级 4c1da020
...@@ -42,40 +42,30 @@ _logger = logging.getLogger('theano.scan_module.scan_op') ...@@ -42,40 +42,30 @@ _logger = logging.getLogger('theano.scan_module.scan_op')
class ScanOp(PureOp): class ScanOp(PureOp):
def __init__(self, def __init__(self,
inputs, inputs,
input_states,
parameters,
non_numeric_input_states,
non_numeric_output_states,
output_states,
outputs, outputs,
lengths, lengths,
mintaps, switches,
options, options,
as_repeatUntil): as_repeatUntil):
self.options = options # name/mode/inplace/gpu/profile
self.inputs = inputs self.inputs = inputs
self.input_states = input_states
self.parameters = parameters
self.non_numeric_input_states = non_numeric_input_states
self.non_numeric_output_states = non_numeric_output_states
self.output_states = output_states
self.outputs = outputs self.outputs = outputs
self.switches = switches
self.lengths = lengths self.lengths = lengths
self.mintaps = mintaps self.mintaps = mintaps
self.name = name
self.mode = mode
self.inplace = inplace
self.gpu = gpu
self.as_repeatUntil = as_repeatUntil self.as_repeatUntil = as_repeatUntil
self.profile = profile self.options = options
self.name = options['name']
self.mode = options['mode']
self.inplace = options['inplace']
self.gpu = options['gpu']
self.profile = options['profile']
self.hash_inner_graph = options['hash_inner_graph']
# --Construct the destroy map--
if self.inplace: if self.inplace:
n_outs = (len(output_states) + len(outputs) + for idx in xrange(len(outputs)):
len(non_numeric_output_states)) self.destroy_map[idx] = [idx + 1]
for idx in xrange(n_outs): # --Decide on the default mode--
self.destroy_map[idx] = [idx + 1 + len(inputs)]
mode_instance = compile.mode.get_mode(self.mode) mode_instance = compile.mode.get_mode(self.mode)
# if the default mode is used, and that mode is ProfileMode # if the default mode is used, and that mode is ProfileMode
# then we need to copy the mode otherwise the time for a given # then we need to copy the mode otherwise the time for a given
...@@ -99,14 +89,9 @@ class ScanOp(PureOp): ...@@ -99,14 +89,9 @@ class ScanOp(PureOp):
self.name = 'scan_fn' self.name = 'scan_fn'
def make_node(self, *inputs): def make_node(self, *inputs):
out_types = [] # Checking if arguments are of the right type is done in the scan
out_types.extend( # function
[out_state.type() for out_state in self.output_states]) out_types = [out.type() for out in self.outputs]
out_types.extend(
[out.type() for out in self.outputs])
out_types.extend(
[non_numeric_out_state.type() for non_numeric_out_state in
self.non_numeric_output_states])
return Apply(self, inputs, out_types) return Apply(self, inputs, out_types)
def __eq__(self, other): def __eq__(self, other):
...@@ -119,25 +104,28 @@ class ScanOp(PureOp): ...@@ -119,25 +104,28 @@ class ScanOp(PureOp):
return False return False
elif not len(self.outputs) == len(other.outputs): elif not len(self.outputs) == len(other.outputs):
return False return False
elif self.info != other.info: if self.mintals != other.mintaps:
return False return False
else: # Check if the number of different types of arguments is the same
# If everything went OK up to here, there is still one thing to diff_args = ['inputs', 'outputs', 'lengths', 'mintaps', 'switches']
# check. Namely, do the internal graph represent same for arg in diff_args:
# computations if len(getattr(self, arg)) != len(getattr(other, arg)):
for self_in, other_in in izip(self.inputs, other.inputs): return False
if self_in.type != other_in.type: for x, y in izip(self.inputs, other.inputs):
return False if x.type != y.type:
return False
if not scan_utils.equal_computations(self.outputs, for x, y in izip(self.lengths, other.lengths):
other.outputs, if x.type != y.type:
self.inputs,
other.inputs):
return False return False
# If they do, then they need to match in other small details s_inputs = [self.t] + self.inputs + self.lengths + self.switches
# like name, mode, etc. o_inputs = [other.t] + other.inputs + other.lengths + other.switches
return True givens = dict(izip(s_inputs, o_inputs))
# This part might be slow
for x, y in izip(self.outputs, other.outputs):
if not gof.graph.is_same_graph(x, y, givens=givens):
return False
return True
def __str__(self): def __str__(self):
if self.gpu: if self.gpu:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论