提交 2b8ea5da authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Following Josh suggestion I renamed a function to something more readable.

I've also make the documentation more clear on the actual behaviour of n_steps.
上级 b04fa107
...@@ -240,9 +240,9 @@ def scan( fn ...@@ -240,9 +240,9 @@ def scan( fn
outputs will have *0 rows*. If the value is negative, ``scan`` outputs will have *0 rows*. If the value is negative, ``scan``
run backwards in time. If the ``go_backwards`` flag is already run backwards in time. If the ``go_backwards`` flag is already
set and also ``n_steps`` is negative, ``scan`` will run forward set and also ``n_steps`` is negative, ``scan`` will run forward
in time. If n stpes is not provided, or evaluates to ``None``, in time. If n stpes is not provided, or is a constant that
``inf`` or ``NaN``, ``scan`` will figure out the amount of evaluates to ``None``, ``inf`` or ``NaN``, ``scan`` will figure
steps it should run given its input sequences. out the amount of steps it should run given its input sequences.
:param truncate_gradient: :param truncate_gradient:
...@@ -454,7 +454,7 @@ def scan( fn ...@@ -454,7 +454,7 @@ def scan( fn
for seq in scan_seqs: for seq in scan_seqs:
lengths_vec.append( seq.shape[0] ) lengths_vec.append( seq.shape[0] )
if not scan_utils.check_NaN_Inf_None(n_steps): if not scan_utils.isNaN_or_Inf_or_None(n_steps):
# ^ N_steps should also be considered # ^ N_steps should also be considered
lengths_vec.append( tensor.as_tensor(n_steps) ) lengths_vec.append( tensor.as_tensor(n_steps) )
...@@ -468,7 +468,7 @@ def scan( fn ...@@ -468,7 +468,7 @@ def scan( fn
# If the user has provided the number of steps, do that regardless ( and # If the user has provided the number of steps, do that regardless ( and
# raise an error if the sequences are not long enough ) # raise an error if the sequences are not long enough )
if scan_utils.check_NaN_Inf_None(n_steps): if scan_utils.isNaN_or_Inf_or_None(n_steps):
actual_n_steps = lengths_vec[0] actual_n_steps = lengths_vec[0]
for contestant in lengths_vec[1:]: for contestant in lengths_vec[1:]:
actual_n_steps = tensor.minimum(actual_n_steps, contestant) actual_n_steps = tensor.minimum(actual_n_steps, contestant)
......
...@@ -205,7 +205,7 @@ def get_updates_and_outputs(outputs_updates): ...@@ -205,7 +205,7 @@ def get_updates_and_outputs(outputs_updates):
return outputs, updates return outputs, updates
def check_NaN_Inf_None(x): def isNaN_or_Inf_or_None(x):
isNone = x is None isNone = x is None
try: try:
isNaN = numpy.isnan(x) isNaN = numpy.isnan(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论