提交 1b2c02ae authored 作者: Razvan Pascanu's avatar Razvan Pascanu

remove extra outdated comments + fixed bug in case return_steps =1, and n_steps =1 or -1

上级 5d9d2303
...@@ -716,14 +716,6 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -716,14 +716,6 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
# make the compilation as fast as possible by not applying any optimization # make the compilation as fast as possible by not applying any optimization
# or conversion to C [ note this region is not important for performance # or conversion to C [ note this region is not important for performance
# so we can do stuff as unoptimal as we wish ] # so we can do stuff as unoptimal as we wish ]
'''
Why did I use gof.graph.inputs to pick the inputs here ??
dummy_f = function(filter(lambda x: isinstance(x,gof.Variable) and \
not isinstance(x,SharedVariable) and not isinstance(x,gof.Constant), \
reversed(gof.graph.inputs(dummy_args))), outputs, updates = updates, mode = \
compile.mode.Mode(linker = 'py', optimizer = None) )
'''
if n_fixed_steps in [-1,1]: if n_fixed_steps in [-1,1]:
''' We do have a special case here, namely is so might happen that ''' We do have a special case here, namely is so might happen that
whatever we have in dummy_args is not sufficient to compile the whatever we have in dummy_args is not sufficient to compile the
...@@ -872,22 +864,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -872,22 +864,8 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
else: else:
# If we do not actually need scan # If we do not actually need scan
for pos, inner_out in enumerate(inner_fn_outs): for pos, inner_out in enumerate(inner_fn_outs):
# check if we are suppose to return just the last step if isinstance(inner_out.type, tensor.TensorType):
# we treat this case differently because the tensor we return inner_fn_outs[pos] = tensor.unbroadcast( tensor.shape_padleft(inner_out),0)
# in this case is different (it has one dimension less)
if return_steps.has_key(pos):
if return_steps[pos] != 1:
# if we return more then one step, we need to add
# one more dimension to our output and make it
# unbroadcastable
inner_fn_outs[pos] = tensor.unbroadcast(
tensor.shape_padleft(inner_out),0)
else:
# same if we do not have any information about how many
# steps we should return (to read return everything in this
# case
inner_fn_outs[pos] = tensor.unbroadcast(
tensor.shape_padleft(inner_out),0)
values = inner_fn_outs values = inner_fn_outs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论