提交 9c9bbd0a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Misc. Scan comment updates

上级 52413fc6
......@@ -896,13 +896,13 @@ def scan(
)
tensor_update = aet.as_tensor_variable(input.update)
sit_sot_inner_outputs.append(tensor_update)
# Not that pos is not a negative index. The sign of pos is used
# Note that `pos` is not a negative index. The sign of `pos` is used
# as a flag to indicate if this output should be part of the
# update rules or part of the standard outputs of scan.
# If `pos` is positive than it corresponds to the standard
# outputs of scan and it refers to output of index `pos`. If `pos`
# is negative that it corresponds to update rules of scan and it
# refers to update rule of index -1 - `pos`.
# update rules or part of the standard outputs of `scan`.
# If `pos` is positive then it corresponds to the standard
# outputs of `scan` and it refers to output of index `pos`. If `pos`
# is negative that it corresponds to update rules of `scan` and it
# refers to the update rule with index `-1 - pos`.
sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
sit_sot_shared.append(input.variable)
givens[input.variable] = new_var
......@@ -913,6 +913,7 @@ def scan(
shared_inner_outputs.append(input.update)
givens[input.variable] = new_var
n_shared_outs += 1
n_sit_sot = len(sit_sot_inner_inputs)
# Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0
......@@ -1006,7 +1007,7 @@ def scan(
# gpuarray is imported here, instead of being imported on top of
# the file because that would force on the user some dependencies that we
# might do not want to. Currently we are working on removing the
# dependencies on sandbox code completeley.
# dependencies on sandbox code completely.
from aesara import gpuarray
if gpuarray.pygpu_activated:
......
......@@ -1736,7 +1736,6 @@ class ScanMerge(GlobalOptimizer):
return ls
for idx, nd in enumerate(nodes):
# Seq
inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inputs), idx))
outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
......
......@@ -39,7 +39,6 @@ from aesara.tensor.subtensor import set_subtensor
from aesara.tensor.var import TensorConstant
# Logging function for sending warning or info
_logger = logging.getLogger("aesara.scan.utils")
......@@ -402,9 +401,9 @@ def get_updates_and_outputs(ls):
list of outputs and the stopping condition returned by the
lambda expression and arrange them in a predefined order.
WRITEME: what is the type of ls? how is it formatted?
if it's not in the predefined order already, how does
this function know how to put it in that order?
WRITEME: what is the type of ls? how is it formatted? if it's not in the
predefined order already, how does this function know how to put it in that
order?
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论