提交 5a61625a 作者: Ricardo Vieira 提交者: Ricardo Vieira

Fix scan_checkpoints with padded sequences

上级 db734614
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
from pytensor.scan.basic import scan from pytensor.scan.basic import scan
from pytensor.tensor.basic import Join from pytensor.tensor.basic import Join
from pytensor.tensor.math import ceil, eq from pytensor.tensor.math import ceil, eq, neq
from pytensor.tensor.subtensor import set_subtensor from pytensor.tensor.subtensor import set_subtensor
...@@ -130,16 +130,18 @@ def scan_checkpoints( ...@@ -130,16 +130,18 @@ def scan_checkpoints(
# Since padding could be an empty tensor, Join returns a view of s. # Since padding could be an empty tensor, Join returns a view of s.
join = Join(view=0) join = Join(view=0)
for i, s in enumerate(sequences): for i, s in enumerate(sequences):
n = s.shape[0] % save_every_N overshoots_by = s.shape[0] % save_every_N
z = ptb.zeros((n, s.shape[1:]), dtype=s.dtype) overshoots = neq(overshoots_by, 0)
sequences[i] = join(0, [s, z]) n = (save_every_N - overshoots_by) * overshoots
z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype)
sequences[i] = join(0, s, z)
# Establish the input variables of the outer scan # Establish the input variables of the outer scan
o_sequences = [ o_sequences = [
s.reshape( s.reshape(
[s.shape[0] / save_every_N, save_every_N] [s.shape[0] // save_every_N, save_every_N]
+ [s.shape[i] for i in range(1, s.ndim)], + [s.shape[i] for i in range(1, s.ndim)],
s.ndim + 1, ndim=s.ndim + 1,
) )
for s in sequences for s in sequences
] ]
......
...@@ -5,7 +5,7 @@ from pytensor.compile.function import function ...@@ -5,7 +5,7 @@ from pytensor.compile.function import function
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.scan.basic import scan from pytensor.scan.basic import scan
from pytensor.scan.checkpoints import scan_checkpoints from pytensor.scan.checkpoints import scan_checkpoints
from pytensor.tensor.basic import ones_like from pytensor.tensor.basic import arange, ones_like
from pytensor.tensor.type import iscalar, vector from pytensor.tensor.type import iscalar, vector
...@@ -13,15 +13,18 @@ class TestScanCheckpoint: ...@@ -13,15 +13,18 @@ class TestScanCheckpoint:
def setup_method(self): def setup_method(self):
self.k = iscalar("k") self.k = iscalar("k")
self.A = vector("A") self.A = vector("A")
seq = arange(self.k, dtype="float32") + 1
result, _ = scan( result, _ = scan(
fn=lambda prior_result, A: prior_result * A, fn=lambda s, prior_result, A: prior_result * A / s,
outputs_info=ones_like(self.A), outputs_info=ones_like(self.A),
sequences=[seq],
non_sequences=self.A, non_sequences=self.A,
n_steps=self.k, n_steps=self.k,
) )
result_check, _ = scan_checkpoints( result_check, _ = scan_checkpoints(
fn=lambda prior_result, A: prior_result * A, fn=lambda s, prior_result, A: prior_result * A / s,
outputs_info=ones_like(self.A), outputs_info=ones_like(self.A),
sequences=[seq],
non_sequences=self.A, non_sequences=self.A,
n_steps=self.k, n_steps=self.k,
save_every_N=100, save_every_N=100,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论