提交 8509cd13 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor Scan tests

- Move test description comments to docstrings - Move optimization tests to `test_opt` - Move `aesara.scan.views` tests to a new `test_views` - Move example-based tests (i.e. mostly non-functionality-focused tests) to a distinct test suite/class - Make rewrite tests specify the presence of the relevant rewrites (assuming that those tests even rely on their namesake rewrites, of course) - Miscellaneous test updates
上级 31ac6dc7
...@@ -392,8 +392,8 @@ def push_out_non_seq_scan(fgraph, node): ...@@ -392,8 +392,8 @@ def push_out_non_seq_scan(fgraph, node):
def push_out_seq_scan(fgraph, node): def push_out_seq_scan(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on constants and sequences. r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
This optimization resembles `PushOutNonSeqScan` but it tries to push, out of This optimization resembles `push_out_non_seq_scan` but it tries to push--out of
the inner function, the computation that only relies on sequence and the inner function--the computation that only relies on sequence and
non-sequence inputs. The idea behind this optimization is that, when it is non-sequence inputs. The idea behind this optimization is that, when it is
possible to do so, it is generally more computationally efficient to perform possible to do so, it is generally more computationally efficient to perform
a single operation on a large tensor rather then perform that same operation a single operation on a large tensor rather then perform that same operation
......
This source diff could not be displayed because it is too large. You can view the blob instead.
import numpy as np import numpy as np
import pytest
import aesara import aesara
import aesara.tensor.basic as at import aesara.tensor.basic as at
from aesara import function, scan, shared
from aesara.compile.io import In
from aesara.compile.mode import get_default_mode
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import Rop, grad, jacobian from aesara.gradient import grad, jacobian
from aesara.graph.basic import clone_replace
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.scan.opt import ScanMerge
from aesara.scan.utils import until
from aesara.tensor.blas import Dot22
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import Dot, dot, sigmoid from aesara.tensor.math import Dot, dot, sigmoid
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh from aesara.tensor.math import tanh
from aesara.tensor.type import matrix, tensor3, vector from aesara.tensor.shape import reshape, shape, specify_shape
from aesara.tensor.type import (
dmatrix,
dvector,
iscalar,
ivector,
matrix,
scalar,
tensor3,
vector,
)
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.scan.test_basic import asarrayX, scan_nodes_from_fct
mode = aesara.compile.mode.get_mode(config.mode) mode = aesara.compile.mode.get_mode(config.mode)
class TestGaussNewton: class TestRemoveConstantsAndUnusedInputsScan:
""" mode = get_default_mode().including("scan")
Regression test for code exhibiting various optimization errors.
This test case is based on code by Sigurd Spieckermann. def test_remove_constants_and_unused_inputs_scan_non_seqs(self):
""" """Test the rewrite `remove_constants_and_unused_inputs_scan` for non-sequences."""
W = matrix(name="W")
v = ivector(name="v")
y1, _ = scan(
lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=[W]
)
y2, _ = scan(
lambda i, _, W: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W[0], W],
)
y3, _ = scan(
lambda i, W, _: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W, W[0]],
)
y4, _ = scan(
lambda i, _, _2, W: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W[0], W[0], W],
)
y5, _ = scan(
lambda i, _, W, _2: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W[0], W, W[0]],
)
y6, _ = scan(
lambda i, W, _, _2: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W, W[0], W[0]],
)
# TODO: y7 have problem during run time. I think it should
# raise an error during the scan construction.
# y7, _ = scan(lambda i, W, _, _2: W[i], sequences=v,
# outputs_info=None, non_sequences=[v, W[0], W])
W_val = np.random.normal(size=(3, 3)).astype(config.floatX)
exp_val = W_val[np.r_[1, 2]]
for out in [y1, y2, y3, y4, y5, y6]:
f = function([W, v], out, mode=self.mode)
res = f(W_val, [1, 2])
assert np.array_equal(res, exp_val)
scan_nodes = scan_nodes_from_fct(f)
assert len(scan_nodes) == 1
scan_node = scan_nodes[0]
assert len(scan_node.inputs[1:]) == len(set(scan_node.inputs[1:]))
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
assert len(inp) == 1
assert len(inp) == len(set(inp))
inp = scan_node.op.outer_non_seqs(scan_node.inputs)
assert len(inp) == 1
assert len(inp) == len(set(inp))
def test_remove_constants_and_unused_inputs_scan_seqs(self):
"""Test the opt remove_constants_and_unused_inputs_scan for sequences."""
W = matrix(name="W")
v = ivector(name="v")
vv = matrix(name="vv")
y1, _ = scan(
lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=[W]
)
y2, _ = scan(
lambda i, _, W: W[i], sequences=[v, v], outputs_info=None, non_sequences=W
)
y3, _ = scan(
lambda i, _, W: W[i],
sequences=[v, vv[0]],
outputs_info=None,
non_sequences=W,
)
y4, _ = scan(
lambda _, i, W: W[i],
sequences=[vv[0], v],
outputs_info=None,
non_sequences=W,
)
y5, _ = scan(
lambda _, i, _2, W: W[i],
sequences=[vv, v, vv[0]],
outputs_info=None,
non_sequences=W,
)
y6, _ = scan(
lambda _, _2, i, W: W[i],
sequences=[vv[0], vv, v],
outputs_info=None,
non_sequences=W,
)
y7, _ = scan(
lambda i, _, _2, W: W[i],
sequences=[v, vv[0], vv[0]],
outputs_info=None,
non_sequences=W,
)
y8, _ = scan(
lambda _, i, W, _2, _3: W[i],
sequences=[vv[0], v],
outputs_info=None,
non_sequences=[W, W[0], W[0]],
)
def setup_method(self): W_val = np.random.normal(size=(3, 3)).astype(config.floatX)
self.rng = np.random.default_rng(utt.fetch_seed()) exp_val = W_val[np.r_[1, 2]]
def _run(self, num_features, num_timesteps, batch_size, mode): for out in [y1, y2, y3, y4, y5, y6, y7, y8]:
# determine shapes of inputs and targets depending on the batch size f = function(
if batch_size == 1: [W, v, vv],
inputs_size = (num_timesteps, num_features) out,
targets_size = (num_timesteps, 1) on_unused_input="ignore",
else: mode=self.mode,
inputs_size = (num_timesteps, batch_size, num_features) )
targets_size = (num_timesteps, batch_size, 1)
# make inputs and targets shared variables res = f(W_val, [1, 2], W_val)
inputs = aesara.shared( assert np.array_equal(res, exp_val)
self.rng.uniform(size=inputs_size).astype(config.floatX), borrow=True
scan_nodes = scan_nodes_from_fct(f)
assert len(scan_nodes) == 1
scan_node = scan_nodes[0]
assert len(scan_node.inputs[1:]) == len(set(scan_node.inputs[1:]))
inp = scan_node.op.inner_seqs(scan_node.op.inputs)
assert len(inp) == 1
inp = scan_node.op.outer_seqs(scan_node.inputs)
assert len(inp) == 1
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
assert len(inp) == 1
inp = scan_node.op.outer_non_seqs(scan_node.inputs)
assert len(inp) == 1
class TestPushOutDot:
mode = get_default_mode().including("scan")
def test_pushout_all(self):
W1 = matrix("W1")
W2 = matrix("W2")
h0 = vector("h0")
def lambda_fn(h, W1, W2):
return dot(h, W1 + W2)
o, _ = scan(lambda_fn, non_sequences=[h0, W1, W2], n_steps=5)
f = function([h0, W1, W2], o, mode=self.mode)
scan_nodes = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert len(scan_nodes) == 0
seed = utt.fetch_seed()
rng = np.random.default_rng(seed)
floatX = config.floatX
v_h = np.array(rng.uniform(size=(2,)), dtype=floatX)
v_W1 = np.array(rng.uniform(size=(2, 2)), dtype=floatX)
v_W2 = np.array(rng.uniform(size=(2, 2)), dtype=floatX)
v_out = np.dot(v_h, v_W1 + v_W2)
sol = np.zeros((5, 2))
# This line is here to make sol have the same shape as the output of
# aesara. Note that what we ask aesara to do is to repeat the 2
# elements vector v_out 5 times
sol[:, :] = v_out
utt.assert_allclose(sol, f(v_h, v_W1, v_W2))
def test_pushout_while(self):
"""
Ensure that the optimizations for Scan that push computation out of
the Scan don't alter the result for 'as_while' scans.
"""
W1 = matrix("W1")
W2 = matrix("W2")
step_indices = vector("step_indices")
def lambda_fn(step_idx, W1, W2):
until_condition = until(step_idx > 2)
return dot(W1, W2), until_condition
# Compile a function with the optimization
o, _ = scan(
lambda_fn, sequences=[step_indices, W1], non_sequences=[W2], n_steps=5
) )
targets = aesara.shared(
self.rng.uniform(size=targets_size).astype(config.floatX), borrow=True f = function([W1, W2, step_indices], o, mode=self.mode)
# Compule an aesara function without the optimization
o, _ = scan(
lambda_fn,
sequences=[step_indices, W1],
non_sequences=[W2],
n_steps=5,
mode="FAST_COMPILE",
) )
# create symbolic inputs and targets variables f_ref = function([W1, W2, step_indices], o, mode=self.mode)
if batch_size == 1:
x = matrix("inputs") # Compare the results of the two implementations
t = matrix("targets") input_values = [
else: np.random.default_rng(utt.fetch_seed()).random((5, 5)).astype("float32"),
x = tensor3("inputs") np.random.default_rng(utt.fetch_seed()).random((5, 5)).astype("float32"),
t = tensor3("inputs") np.arange(5).astype("float32"),
x.tag.test_value = inputs.get_value(borrow=True) ]
t.tag.test_value = targets.get_value(borrow=True)
out = f(*input_values)
out_ref = f_ref(*input_values)
utt.assert_allclose(out, out_ref)
def test_pushout(self):
W1 = matrix("W1")
W2 = matrix("W2")
h0 = vector("h0")
def lambda_fn(h, W1, W2):
return dot(h, W1 + W2)
o, _ = scan(lambda_fn, outputs_info=h0, non_sequences=[W1, W2], n_steps=5)
f = function([h0, W1, W2], o, mode=self.mode)
# create a set of parameters for a simple RNN scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0]
W_xh = aesara.shared( assert (
(0.01 * self.rng.uniform(size=(num_features, 10))).astype(config.floatX), len(
borrow=True, [
x
for x in scan_node.op.fn.maker.fgraph.toposort()
if isinstance(x.op, Elemwise)
]
)
== 0
) )
W_hh = aesara.shared(
(0.01 * self.rng.uniform(size=(10, 10))).astype(config.floatX), borrow=True def test_pushout_nomodif(self):
inp = matrix("inp")
def fn(i, i_tm1):
return i + 10, i_tm1
([i_t, i_tm1], _) = scan(
fn,
sequences=[inp],
outputs_info=[np.asarray([0.0, 0.0], config.floatX), None],
) )
W_hy = aesara.shared( f = function([inp], [i_t, i_tm1])
(0.01 * self.rng.uniform(size=(10, 1))).astype(config.floatX), borrow=True val = np.arange(10).reshape(5, 2).astype(config.floatX)
ret = f(val)
utt.assert_allclose(ret[0], val + 10)
utt.assert_allclose(
ret[1], [[0.0, 0.0], [10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]]
) )
b_h = aesara.shared(np.zeros(10).astype(config.floatX), borrow=True)
b_y = aesara.shared(np.zeros(1).astype(config.floatX), borrow=True)
params = [W_xh, W_hh, W_hy, b_h, b_y]
# recurrent function class TestPushOutNonSeqScan:
def step(x_t, h_tm1):
h = tanh(dot(h_tm1, W_hh) + dot(x_t, W_xh) + b_h)
return h
# build recurrent graph
if batch_size == 1:
h_0 = at.alloc(0.0, 10).astype(config.floatX)
else:
h_0 = at.alloc(0.0, batch_size, 10).astype(config.floatX)
h, updates = aesara.scan(step, sequences=[x], outputs_info=[h_0])
# network output
y = dot(h, W_hy) + b_y
# Create Gauss-Newton-Matrix object. Not really of any use here, but I
# need it for Hessian-Free optimization.
gn = GaussNewtonMatrix(y)
# compute MSE
cost = ((t - y) ** 2).sum(axis=1).mean()
# Compute the cost at some other point in the parameter
# space. Not really of any use here, but this is how I do it
# during certain iterations of CG in the HF algorithm. There,
# it's in fact `pi + current update proposal`. For simplicity,
# I just multiply by 2 here.
cost_ = aesara.clone_replace(cost, replace={pi: 2 * pi for pi in params})
# Compute Gauss-Newton-Matrix times some vector `v` which is `p` in CG,
# but for simplicity, I just take the parameters vector because it's
# already there.
Gv = gn(v=params, cost=cost, parameters=params, damp=at.constant(1.0))
# compile Aesara function
f = aesara.function([], [cost_] + Gv, givens={x: inputs, t: targets}, mode=mode)
# execute
f()
def test_batch(self):
# This runs fine. The batch size is set to something greater than 1,
# i.e. the data is represented by a tensor3 object.
self._run(100, 10, batch_size=5, mode=mode)
def test_nobatch(self):
# This used to give an error due to optimization "scan_merge_inouts".
# The batch size is set to 1 and the data is represented by a matrix.
self._run(100, 10, batch_size=1, mode=mode)
class GaussNewtonMatrix:
def __init__(self, s):
# `s` is the linear network outputs, i.e. the network output
# without having applied the activation function
self._s = s
def __call__(self, v, cost, parameters, damp):
# compute Gauss-Newton Matrix right-multiplied by `v`
Jv = Rop(self._s, parameters, v)
HJv = grad(at_sum(grad(cost, self._s) * Jv), self._s, consider_constant=[Jv])
JHJv = grad(at_sum(HJv * self._s), parameters, consider_constant=[HJv, Jv])
# apply Tikhonov damping
JHJv = [JHJvi + damp * vi for JHJvi, vi in zip(JHJv, v)]
return JHJv
class TestPushOutScanOutputDot:
""" """
Test class for the PushOutScanOutput optimizer in the case where the inner Tests for the `push_out_non_seq_scan` optimization in the case where the inner
function of a scan op has an output which is the result of a Dot product function of a `Scan` `Op` has an output which is the result of a `Dot` product
on a non-sequence matrix input to scan and a vector that is the result of on a non-sequence matrix input to `Scan` and a vector that is the result of
computation in the inner function. computation in the inner function.
""" """
def test_pushout_seqs(self):
def init_predictive_output(inputs, targets, hyp, x_star, s_star):
E = hyp.shape[0]
def init_K(i, X, Y):
XX = X.sum(1).reshape((X.shape[0], 1))
K = XX + XX.T
return K.sum()
beta, K_updts = scan(
init_K, sequences=at.arange(E), non_sequences=[inputs, targets]
)
# mean
def predict_mean_i(i, x_star, s_star, X, beta, h):
n, D = shape(X)
# rescale every dimension by the corresponding inverse lengthscale
iL = at.diag(h[i, :D])
inp = (X - x_star).dot(iL)
# compute the mean
B = iL.dot(s_star).dot(iL)
t = inp.dot(B)
lb = (inp * t).sum() + beta.sum()
Mi = at_sum(lb) * h[i, D]
return Mi
(M), M_updts = scan(
predict_mean_i,
sequences=at.arange(E),
non_sequences=[x_star, s_star, inputs, beta, hyp],
)
return M
# some initializations
hypx = np.log(np.tile([1, 1, 1, 1, 1, 1, 0.01], (3, 1)))
# variables used in the following expressions
hyp = shared(hypx)
inputs = dmatrix("X")
targets = dmatrix("Y")
x_star = dvector("x_star")
s_star = dmatrix("s_star")
M = init_predictive_output(inputs, targets, hyp, x_star, s_star)
X = np.random.default_rng(utt.fetch_seed()).random((10, 4))
Y = np.random.default_rng(utt.fetch_seed()).random((10, 3))
test_m = np.random.default_rng(utt.fetch_seed()).random((4,))
test_s = np.eye(4)
# Compute expected outputs (jacobian of M wrt x_star)
dfdm = function(
[inputs, targets, x_star, s_star],
[
grad(M[0], x_star),
grad(M[1], x_star),
grad(M[2], x_star),
],
)
expected_output = dfdm(X, Y, test_m, test_s)
# equivalent code for the jacobian using scan
dMdm, dMdm_updts = scan(
lambda i, M, x: grad(M[i], x),
sequences=at.arange(M.shape[0]),
non_sequences=[M, x_star],
)
dfdm = function([inputs, targets, x_star, s_star], [dMdm[0], dMdm[1], dMdm[2]])
scan_output = dfdm(X, Y, test_m, test_s)
dMdm_j = jacobian(M, x_star)
dfdm_j = function(
[inputs, targets, x_star, s_star], [dMdm_j[0], dMdm_j[1], dMdm_j[2]]
)
jacobian_outputs = dfdm_j(X, Y, test_m, test_s)
utt.assert_allclose(expected_output, scan_output)
utt.assert_allclose(expected_output, jacobian_outputs)
@config.change_flags(on_opt_error="raise")
def test_pushout_seqs2(self):
x = matrix()
outputs, updates = scan(
lambda x: [x * x, at.constant(0).copy().copy()],
n_steps=2,
sequences=[],
non_sequences=[],
outputs_info=[x, None],
)
# Compile an Aesara function where any optimization error will lead to
# an exception being raised
function([x], outputs, updates=updates)
@config.change_flags(on_opt_error="raise")
def test_pushout_nonseq(self):
"""
This test was created for a crashed that occurred during the
optimization `PushOutNonSeqScan` when it attempted to a scan node with
two outputs but only providing a replacement for one of those
outputs. This led the optimization to raise an exception.
"""
outputs, _ = scan(lambda x: (x * x, x), non_sequences=[2], n_steps=2)
f = function(inputs=[], outputs=outputs)
outs = f()
expected_outs = [[4, 4], [2, 2]]
utt.assert_allclose(outs, expected_outs)
def test_dot_not_output(self): def test_dot_not_output(self):
# Test the case where the vector input to the dot is not already an """
# output of the inner function. Test the case where the vector input to the dot is not already an
output of the inner function.
"""
v = vector() v = vector()
m = matrix() m = matrix()
...@@ -179,8 +458,10 @@ class TestPushOutScanOutputDot: ...@@ -179,8 +458,10 @@ class TestPushOutScanOutputDot:
utt.assert_allclose(output_opt, output_no_opt) utt.assert_allclose(output_opt, output_no_opt)
def test_dot_nitsot_output(self): def test_dot_nitsot_output(self):
# Test the case where the vector input to the dot is already a nitsot """
# output of the inner function. Test the case where the vector input to the dot is already a nitsot
output of the inner function.
"""
a = matrix() a = matrix()
b = matrix() b = matrix()
...@@ -224,8 +505,10 @@ class TestPushOutScanOutputDot: ...@@ -224,8 +505,10 @@ class TestPushOutScanOutputDot:
utt.assert_allclose(output_opt[1], output_no_opt[1]) utt.assert_allclose(output_opt[1], output_no_opt[1])
def test_dot_sitsot_output(self): def test_dot_sitsot_output(self):
# Test the case where the vector input to the dot is not already a """
# non-nitsot (in this case a sitsot) output of the inner function. Test the case where the vector input to the dot is not already a
non-nitsot (in this case a sitsot) output of the inner function.
"""
a = matrix() a = matrix()
b = matrix() b = matrix()
...@@ -268,19 +551,56 @@ class TestPushOutScanOutputDot: ...@@ -268,19 +551,56 @@ class TestPushOutScanOutputDot:
utt.assert_allclose(output_opt[1], output_no_opt[1]) utt.assert_allclose(output_opt[1], output_no_opt[1])
class TestPushOutSumOfDot: class TestPushOutAddScan:
""" """
Test case for the PushOutScanOutput optimizer in the case where the scan Test case for the `push_out_add_scan` optimization in the case where the `Scan`
is used to compute the sum over the dot products between the corresponding is used to compute the sum over the dot products between the corresponding
elements of two list of matrices. elements of two list of matrices.
TODO FIXME XXX: These aren't real tests; they simply confirm that a few
graph that could be relevant to the push-out optimizations can be compiled
and evaluated. None of them confirm that a push-out optimization has been
performed.
""" """
def test_sum_dot(self):
A = matrix("A")
B = matrix("B")
S, _ = scan(
lambda x1, x2, u: u + dot(x1, x2),
sequences=[A.dimshuffle(0, 1, "x"), B.dimshuffle(0, "x", 1)],
outputs_info=[at.zeros_like(A)],
)
f = function([A, B], S.owner.inputs[0][-1])
rng = np.random.default_rng(utt.fetch_seed())
vA = rng.uniform(size=(5, 5)).astype(config.floatX)
vB = rng.uniform(size=(5, 5)).astype(config.floatX)
utt.assert_allclose(f(vA, vB), np.dot(vA.T, vB))
def test_pregreedy_optimizer(self):
W = at.zeros((5, 4))
bv = at.zeros((5,))
bh = at.zeros((4,))
v = matrix("v")
(bv_t, bh_t), _ = scan(
lambda _: [bv, bh], sequences=v, outputs_info=[None, None]
)
chain, _ = scan(
lambda x: dot(dot(x, W) + bh_t, W.T) + bv_t,
outputs_info=v,
n_steps=2,
)
# TODO FIXME: Make this a real test and assert something.
function([v], chain)(np.zeros((3, 5), dtype=config.floatX))
def test_machine_translation(self): def test_machine_translation(self):
# This test case comes from https://github.com/rizar/scan-grad-speed and """
# is an example of actual computation done with scan in the context of This test case comes from https://github.com/rizar/scan-grad-speed and
# machine translation is an example of actual computation done with scan in the context of
# machine translation.
# 'dim' has been reduced from 1000 to 5 to make the test run faster
`dim` has been reduced from 1000 to 5 to make the test run faster
"""
# Parameters from an actual machine translation run # Parameters from an actual machine translation run
batch_size = 80 batch_size = 80
...@@ -376,7 +696,7 @@ class TestPushOutSumOfDot: ...@@ -376,7 +696,7 @@ class TestPushOutSumOfDot:
utt.assert_allclose(f_opt_output, f_no_opt_output) utt.assert_allclose(f_opt_output, f_no_opt_output)
def test_non_zero_init(self): def test_non_zero_init(self):
# Test the case where the initial value for the nitsot output is non-zero """Test the case where the initial value for the nitsot output is non-zero."""
input1 = tensor3() input1 = tensor3()
input2 = tensor3() input2 = tensor3()
...@@ -427,3 +747,706 @@ class TestPushOutSumOfDot: ...@@ -427,3 +747,706 @@ class TestPushOutSumOfDot:
output_no_opt = f_no_opt(input1_value, input2_value, input3_value) output_no_opt = f_no_opt(input1_value, input2_value, input3_value)
utt.assert_allclose(output_opt, output_no_opt) utt.assert_allclose(output_opt, output_no_opt)
class TestScanMerge:
mode = get_default_mode().including("scan")
def test_basic(self):
x = vector()
y = vector()
def sum(s):
return s + 1
sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[y])
f = function(
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
sx, upx = scan(sum, sequences=[x], n_steps=2)
sy, upy = scan(sum, sequences=[y], n_steps=3)
f = function(
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
sx, upx = scan(sum, sequences=[x], n_steps=4)
sy, upy = scan(sum, sequences=[y], n_steps=4)
f = function(
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 1
sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[x])
f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 1
sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[x], mode="FAST_COMPILE")
f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 1
sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[x], truncate_gradient=1)
f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
def test_three_scans(self):
r"""
This test checks a case where we have three `Scan`\s, two of them
cannot be merged together, but the third one can be merged with
either.
"""
x = vector()
y = vector()
def sum(s):
return s + 1
sx, upx = scan(sum, sequences=[x], n_steps=4, name="X")
# We need to use an expression of y rather than y so the toposort
# comes up with the 'Y' scan last.
sy, upy = scan(sum, sequences=[2 * y + 2], n_steps=4, name="Y")
sz, upz = scan(sum, sequences=[sx], n_steps=4, name="Z")
f = function(
[x, y], [sy, sz], mode=self.mode.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
rng = np.random.default_rng(utt.fetch_seed())
x_val = rng.uniform(size=(4,)).astype(config.floatX)
y_val = rng.uniform(size=(4,)).astype(config.floatX)
# Run it so DebugMode can detect optimization problems.
f(x_val, y_val)
def test_belongs_to_set(self):
"""
Test the method belongs_to of this class. Specifically see if it
detects the two `Scan` nodes as not being similar.
"""
inps = vector()
state = scalar()
y1, _ = scan(lambda x, y: x * y, sequences=inps, outputs_info=state, n_steps=5)
y2, _ = scan(
lambda x, y: (x + y, until(x > 0)),
sequences=inps,
outputs_info=state,
n_steps=5,
)
scan_node1 = y1.owner.inputs[0].owner
assert isinstance(scan_node1.op, Scan)
scan_node2 = y2.owner.inputs[0].owner
assert isinstance(scan_node2.op, Scan)
opt_obj = ScanMerge()
assert not opt_obj.belongs_to_set(scan_node1, [scan_node2])
assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])
class TestScanInplaceOptimizer:
mode = get_default_mode().including("scan_make_inplace", "inplace")
@utt.assertFailure_fast
def test_simple_rnn(self):
"""Simple RNN; compute inplace version 1."""
rng = np.random.default_rng(utt.fetch_seed())
vW = asarrayX(np.random.uniform())
vW_in = asarrayX(np.random.uniform())
vu0 = asarrayX(rng.uniform(-5.0, 5.0, size=(3,)))
vu1 = asarrayX(rng.uniform(-5.0, 5.0, size=(3,)))
vu2 = asarrayX(rng.uniform(-5.0, 5.0, size=(3,)))
vx0 = asarrayX(rng.uniform())
vx1 = asarrayX(rng.uniform())
u0 = vector("u0")
u1 = vector("u1")
u2 = vector("u2")
mu0 = In(u0, mutable=False)
mu1 = In(u1, mutable=True)
mu2 = In(u2, mutable=True)
x0 = scalar("x0")
x1 = scalar("y0")
W_in = shared(vW_in, "Win")
W = shared(vW, "W")
def f_rnn_shared(u0_t, u1_t, u2_t, x0_tm1, x1_tm1):
return [
u0_t * W_in + x0_tm1 * W + u1_t * u2_t,
u0_t * W_in + x1_tm1 * W + u1_t + u2_t,
]
outputs, updates = scan(
f_rnn_shared,
[u0, u1, u2],
[dict(initial=x0, inplace=u2), dict(initial=x1, inplace=u1)],
[],
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
mode=self.mode,
)
f9 = function(
[mu0, mu1, mu2, x0, x1],
outputs,
updates=updates,
mode=self.mode,
allow_input_downcast=True,
)
scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert 0 in scan_node[0].op.destroy_map.keys()
assert 1 in scan_node[0].op.destroy_map.keys()
# compute output in numpy
numpy_x0 = np.zeros((3,))
numpy_x1 = np.zeros((3,))
numpy_x0[0] = vu0[0] * vW_in + vx0 * vW + vu1[0] * vu2[0]
numpy_x1[0] = vu0[0] * vW_in + vx1 * vW + vu1[0] + vu2[0]
for i in range(1, 3):
numpy_x0[i] = vu0[i] * vW_in + numpy_x0[i - 1] * vW + vu1[i] * vu2[i]
numpy_x1[i] = vu0[i] * vW_in + numpy_x1[i - 1] * vW + vu1[i] + vu2[i]
# note aesara computes inplace, so call function after numpy
# equivalent is done
(aesara_x0, aesara_x1) = f9(vu0, vu1, vu2, vx0, vx1)
# assert that aesara does what it should
utt.assert_allclose(aesara_x0, numpy_x0)
utt.assert_allclose(aesara_x1, numpy_x1)
@utt.assertFailure_fast
def test_simple_rnn_2(self):
"""Simple RNN; compute inplace version 2."""
rng = np.random.default_rng(utt.fetch_seed())
vW = asarrayX(np.random.uniform())
vW_in = asarrayX(np.random.uniform())
vu0 = asarrayX(rng.uniform(-5.0, 5.0, size=(3,)))
vu1 = asarrayX(rng.uniform(-5.0, 5.0, size=(4,)))
vu2 = asarrayX(rng.uniform(-5.0, 5.0, size=(5,)))
vx0 = asarrayX(rng.uniform())
vx1 = asarrayX(rng.uniform())
u0 = vector("u0")
u1 = vector("u1")
u2 = vector("u2")
mu0 = In(u0, mutable=True)
mu1 = In(u1, mutable=True)
mu2 = In(u2, mutable=True)
x0 = scalar("x0")
x1 = scalar("y0")
W_in = shared(vW_in, "Win")
W = shared(vW, "W")
def f_rnn_shared(u0_t, u1_t, u1_tp1, u2_tm1, u2_t, u2_tp1, x0_tm1, x1_tm1):
return [
u0_t * W_in + x0_tm1 * W + u1_t * u1_tp1,
u0_t * W_in + x1_tm1 * W + u2_tm1 + u2_t + u2_tp1,
]
outputs, updates = scan(
f_rnn_shared,
[u0, dict(input=u1, taps=[0, 1]), dict(input=u2, taps=[-1, 0, +1])],
[dict(initial=x0), dict(initial=x1)],
[],
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
mode=self.mode,
)
f9 = function(
[mu0, mu1, mu2, x0, x1],
outputs,
updates=updates,
mode=self.mode,
allow_input_downcast=True,
)
scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert 0 in scan_node[0].op.destroy_map.keys()
assert 1 in scan_node[0].op.destroy_map.keys()
# compute output in numpy
numpy_x0 = np.zeros((3,))
numpy_x1 = np.zeros((3,))
numpy_x0[0] = vu0[0] * vW_in + vx0 * vW + vu1[0] * vu1[1]
numpy_x1[0] = vu0[0] * vW_in + vx1 * vW + vu2[0] + vu2[1] + vu2[2]
for i in range(1, 3):
numpy_x0[i] = vu0[i] * vW_in + numpy_x0[i - 1] * vW + vu1[i] * vu1[i + 1]
numpy_x1[i] = (
vu0[i] * vW_in + numpy_x1[i - 1] * vW + vu2[i] + vu2[i + 1] + vu2[i + 2]
)
# note aesara computes inplace, so call function after numpy
# equivalent is done
(aesara_x0, aesara_x1) = f9(vu0, vu1, vu2, vx0, vx1)
# assert that aesara does what it should
utt.assert_allclose(aesara_x0, numpy_x0)
utt.assert_allclose(aesara_x1, numpy_x1)
@utt.assertFailure_fast
def test_inplace3(self):
rng = np.random.default_rng(utt.fetch_seed())
vx0 = asarrayX(rng.uniform())
vx1 = asarrayX(rng.uniform())
x0 = shared(vx0)
x1 = shared(vx1)
outputs, updates = scan(
lambda x, y: (x + asarrayX(1), y + asarrayX(1)), [], [x0, x1], n_steps=3
)
x0 = asarrayX(np.zeros((3,)))
x0[0] = vx0
x0 = at.constant(x0)
to_replace = outputs[0].owner.inputs[0].owner.inputs[1]
outputs = clone_replace(outputs, replace=[(to_replace, x0)])
f9 = function([], outputs, updates=updates, mode=self.mode)
scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert 0 not in scan_node[0].op.destroy_map.keys()
assert 1 in scan_node[0].op.destroy_map.keys()
class TestSaveMem:
mode = get_default_mode().including("scan_save_mem", "save_mem_new_scan")
def test_save_mem(self):
rng = np.random.default_rng(utt.fetch_seed())
vW_in2 = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
vW = asarrayX(rng.uniform(-0.5, 0.5, size=(2, 2)))
vWout = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
vW_in1 = asarrayX(rng.uniform(-0.5, 0.5, size=(2, 2)))
v_u1 = asarrayX(rng.uniform(-0.5, 0.5, size=(8, 2)))
v_u2 = asarrayX(rng.uniform(-0.5, 0.5, size=(8,)))
v_x0 = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
v_y0 = asarrayX(rng.uniform(size=(3,)))
W_in2 = shared(vW_in2, name="win2")
W = shared(vW, name="w")
W_out = shared(vWout, name="wout")
W_in1 = matrix("win")
u1 = matrix("u1")
u2 = vector("u2")
x0 = vector("x0")
y0 = vector("y0")
def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1):
return [
y_tm3 + 1,
dot(u1_t, W_in1) + u2_t * W_in2 + dot(x_tm1, W),
y_tm1 + dot(x_tm1, W_out),
]
_outputs, updates = scan(
f_rnn_cmpl,
[u1, u2],
[None, dict(initial=x0), dict(initial=y0, taps=[-1, -3])],
W_in1,
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
)
outputs = [_outputs[0][-1], _outputs[1][-1], _outputs[2][-1]]
f4 = function(
[u1, u2, x0, y0, W_in1],
outputs,
updates=updates,
allow_input_downcast=True,
mode=self.mode,
)
# compute the values in numpy
v_x = np.zeros((8, 2), dtype=config.floatX)
v_y = np.zeros((8,), dtype=config.floatX)
v_x[0] = np.dot(v_u1[0], vW_in1) + v_u2[0] * vW_in2 + np.dot(v_x0, vW)
v_y[0] = np.dot(v_x0, vWout) + v_y0[2]
for i in range(1, 8):
v_x[i] = np.dot(v_u1[i], vW_in1) + v_u2[i] * vW_in2 + np.dot(v_x[i - 1], vW)
v_y[i] = np.dot(v_x[i - 1], vWout) + v_y[i - 1]
(aesara_dump, aesara_x, aesara_y) = f4(v_u1, v_u2, v_x0, v_y0, vW_in1)
utt.assert_allclose(aesara_x, v_x[-1:])
utt.assert_allclose(aesara_y, v_y[-1:])
def test_save_mem_reduced_number_of_steps(self):
def f_rnn(u_t):
return (
u_t + 1.0,
u_t + 2.0,
u_t + 3.0,
u_t + 4.0,
u_t + 5.0,
u_t + 6.0,
u_t + 7.0,
)
u = vector("u")
idx = iscalar("idx")
jdx = iscalar("jdx")
[x1, x2, x3, x4, x5, x6, x7], updates = scan(
f_rnn, u, n_steps=None, truncate_gradient=-1, go_backwards=False
)
f2 = function(
[u, idx, jdx],
[x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]],
updates=updates,
allow_input_downcast=True,
mode=self.mode,
)
# get random initial values
rng = np.random.default_rng(utt.fetch_seed())
v_u = rng.uniform(-5.0, 5.0, size=(20,))
# compute the output in numpy
tx1, tx2, tx3, tx4, tx5, tx6, tx7 = f2(v_u, 3, 15)
utt.assert_allclose(tx1, v_u[:2] + 1.0)
utt.assert_allclose(tx2, v_u[4] + 2.0)
utt.assert_allclose(tx3, v_u[3] + 3.0)
utt.assert_allclose(tx4, v_u[:3] + 4.0)
utt.assert_allclose(tx5, v_u[-10] + 5.0)
utt.assert_allclose(tx6, v_u[-15] + 6.0)
utt.assert_allclose(tx7, v_u[:-15] + 7.0)
def test_save_mem_store_steps(self):
def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
return (
u_t + 1.0,
u_t + 2.0,
u_t + 3.0,
u_t + 4.0,
u_t + 5.0,
u_t + 6.0,
u_t + 7.0,
)
u = vector("u")
x10 = vector("x10")
x20 = scalar("x20")
x30 = vector("x30")
x40 = scalar("x40")
[x1, x2, x3, x4, x5, x6, x7], updates = scan(
f_rnn,
u,
[
None,
None,
None,
dict(initial=x10, taps=[-1, -2]),
x20,
dict(initial=x30, taps=[-1, -2]),
x40,
],
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
)
f2 = function(
[u, x10, x20, x30, x40],
[x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
updates=updates,
allow_input_downcast=True,
mode=self.mode,
)
# get random initial values
rng = np.random.default_rng(utt.fetch_seed())
v_u = rng.uniform(-5.0, 5.0, size=(20,))
# compute the output in numpy
tx1, tx2, tx3, tx4, tx5 = f2(v_u, [0, 0], 0, [0, 0], 0)
utt.assert_allclose(tx1, v_u[-7] + 1.0)
utt.assert_allclose(tx2, v_u[-3:-1] + 2.0)
utt.assert_allclose(tx3, v_u[-6:] + 3.0)
utt.assert_allclose(tx4, v_u[-1] + 4.0)
utt.assert_allclose(tx5, v_u[-1] + 5.0)
def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
var = at.ones(())
values, _ = scan(
lambda x: ([x], (), until(x)),
outputs_info=[var],
n_steps=2,
)
tmp_fn = function([var], values, mode=self.mode)
scan_nodes = [
x for x in tmp_fn.maker.fgraph.toposort() if isinstance(x.op, Scan)
]
assert len(scan_nodes) == 1
    def test_savemem_opt(self):
        """Smoke-test save-mem with a ``-2`` tap on the first output.

        Only ``y2.sum()`` is requested from the compiled function, so the
        save-mem rewrite has the opportunity to shrink the buffer kept for
        ``y1``; this only checks that compilation and execution succeed.
        """
        y0 = shared(np.ones((2, 10)))
        [y1, y2], updates = scan(
            lambda y: [y, y],
            outputs_info=[dict(initial=y0, taps=[-2]), None],
            n_steps=5,
        )
        # TODO FIXME: Make this a real test and assert something.
        function([], y2.sum(), mode=self.mode)()
    def test_savemem_opt_0_step(self):
        """
        Test a case where the savemem optimization has the opportunity to
        lower the number of steps of a Scan to 0. It tests that the
        optimization doesn't do so since Scan nodes with 0
        steps are not currently supported and doing so would result in a
        crash during the function execution.
        """

        def inner_scan_step(x_t_t, h_tm1, w):
            # Simple linear recurrence: h_t = h_{t-1} @ w + x_t.
            return dot(h_tm1, w) + x_t_t

        def outer_scan_step(x_t, w):
            # The inner scan consumes all but the first row of `x_t` and is
            # seeded with the first row, so it runs `x_t.shape[0] - 1` steps.
            h, _ = scan(
                inner_scan_step,
                sequences=[x_t[1:]],
                outputs_info=[x_t[0]],
                non_sequences=[w],
                strict=True,
                name="the_inner_scan",
            )
            return h

        def get_outputs(x, w):
            # Nest the scans, then differentiate through both of them.
            features, _ = scan(
                outer_scan_step,
                sequences=[x],
                non_sequences=[w],
                strict=True,
                name="the_outer_scan",
            )

            return_val = grad(features.sum(), w)
            return return_val

        # Compile the aesara function
        x = tensor3("x")
        w = matrix("w")
        f = function(inputs=[x, w], outputs=get_outputs(x, w), mode=self.mode)

        # Test the function to ensure it returns valid results
        x_value = (
            np.random.default_rng(utt.fetch_seed())
            .random((2, 2, 3))
            .astype(config.floatX)
        )
        w_value = (
            np.random.default_rng(utt.fetch_seed()).random((3, 3)).astype(config.floatX)
        )
        expected_output = np.tile(x_value[:, 0].sum(0), (3, 1)).transpose()

        output = f(x_value, w_value)
        utt.assert_allclose(output, expected_output)
    @pytest.mark.skip(
        reason="The 'assertion' of this test relied on something that no longer exists "
    )
    def test_subtensor_multiple_slices(self):
        r"""
        This addresses a bug that happens when you have multiple subtensors
        on the output of `Scan`. The bug requires the reshape to be produced,
        and it has something to do with how the `Subtensor`\s overlap.
        """

        def f_pow2(x_tm1):
            # Doubling recurrence: x_t = 2 * x_{t-1}.
            return 2 * x_tm1

        state = vector("state")
        n_steps = iscalar("nsteps")
        output, updates = scan(
            f_pow2,
            [],
            state,
            [],
            n_steps=n_steps,
            truncate_gradient=-1,
            go_backwards=False,
        )
        nw_shape = ivector("nw_shape")
        # Note that the output is reshaped to 3 dimensional tensor, and
        my_f = function(
            [state, n_steps, nw_shape],
            [reshape(output, nw_shape, ndim=3)[:-2], output[:-4]],
            updates=updates,
            allow_input_downcast=True,
        )
        nodes = [x for x in my_f.maker.fgraph.toposort() if isinstance(x.op, Scan)]
        # This assertion fails if savemem optimization failed on scan
        if config.mode != "FAST_COMPILE":
            # NOTE(review): `_scan_savemem_visited` appears to be the removed
            # attribute that motivated the skip above -- confirm a replacement
            # signal before re-enabling this test.
            assert nodes[0].op._scan_savemem_visited
        rng = np.random.default_rng(utt.fetch_seed())
        my_f(rng.uniform(size=(3,)), 4, np.int64([2, 2, 3]))
def test_inner_replace_dot():
    """Check that rewrites are applied to the `Scan` inner-graph.

    In particular, BLAS-based rewrites should remove the original `Dot`
    from the inner function.  (This test was previously named as if it
    exercised the `Scan` push-out rewrites, but those were never being
    applied at all.)
    """
    W = matrix("W")
    h = matrix("h")

    mode = get_default_mode().including("scan")  # .excluding("BlasOpt")

    def step(hi, him1, W):
        return hi, dot(hi + him1, W)

    o, _ = scan(
        step,
        outputs_info=[at.zeros([h.shape[1]]), None],
        sequences=[h],
        non_sequences=[W],
        mode=mode,
    )

    f = function([W, h], o, mode=mode)

    scan_nodes = [
        node for node in f.maker.fgraph.toposort() if isinstance(node.op, Scan)
    ]
    assert len(scan_nodes) == 1

    # After the BLAS rewrites, no plain `Dot` should remain in the inner graph.
    inner_fgraph = scan_nodes[0].op.fn.maker.fgraph
    assert all(not isinstance(node.op, Dot) for node in inner_fgraph.apply_nodes)
def test_alloc_inputs1():
    """Inner-graph `Elemwise` computation on non-sequences gets pushed out.

    ``W1 * at.zeros_like(W2)`` depends only on non-sequence inputs, so after
    the push-out rewrites the inner function should contain no `Elemwise`
    nodes at all.
    """
    W1 = matrix("W1")
    W2 = matrix("W2")
    h0 = vector("h0")

    def lambda_fn(h, W1, W2):
        return dot(h, W1 * W2)

    o, _ = scan(
        lambda_fn,
        outputs_info=h0,
        non_sequences=[W1, at.zeros_like(W2)],
        n_steps=5,
    )

    f = function([h0, W1, W2], o, mode=get_default_mode().including("scan"))
    scan_node = next(
        node for node in f.maker.fgraph.toposort() if isinstance(node.op, Scan)
    )
    # Idiomatic form of the original ``len([...]) == 0`` check: no
    # `Elemwise` node may remain in the inner graph.
    assert not any(
        isinstance(node.op, Elemwise)
        for node in scan_node.op.fn.maker.fgraph.toposort()
    )
@pytest.mark.skip(
    reason="This tests depends on an optimization for "
    "scan that has not been implemented yet."
)
def test_alloc_inputs2():
    """Like `test_alloc_inputs1`, but the `Elemwise` also consumes a sequence.

    Pushing ``W1 * dot(h, W2)`` out of the inner graph requires a rewrite
    that does not exist yet, hence the skip.
    """
    W1 = matrix()
    W2 = matrix()
    h0 = vector()

    def lambda_fn(W1, h, W2):
        return W1 * dot(h, W2)

    o, _ = scan(
        lambda_fn,
        sequences=at.zeros_like(W1),
        outputs_info=h0,
        non_sequences=[at.zeros_like(W2)],
        n_steps=5,
    )

    f = function([h0, W1, W2], o, mode=get_default_mode().including("scan"))
    scan_node = next(
        node for node in f.maker.fgraph.toposort() if isinstance(node.op, Scan)
    )
    # Idiomatic form of the original ``len([...]) == 0`` check: no
    # `Elemwise` node may remain in the inner graph.
    assert not any(
        isinstance(node.op, Elemwise)
        for node in scan_node.op.fn.maker.fgraph.toposort()
    )
def test_alloc_inputs3():
    """With fully-specified shapes, the compiled `Scan` keeps a single input."""
    _W1 = matrix()
    _W2 = matrix()
    _h0 = vector()

    W1 = specify_shape(_W1, (3, 3))
    W2 = specify_shape(_W2, (3, 3))
    h0 = specify_shape(_h0, (3,))

    def step(W1, h, W2):
        return W1 * dot(h, W2)

    o, _ = scan(
        step,
        sequences=at.zeros_like(W1),
        outputs_info=h0,
        non_sequences=[at.zeros_like(W2)],
        n_steps=5,
    )

    # TODO FIXME: This result depends on unrelated rewrites in the "fast" mode.
    f = function([_h0, _W1, _W2], o, mode="FAST_RUN")

    scan_node = next(
        node for node in f.maker.fgraph.toposort() if isinstance(node.op, Scan)
    )
    assert len(scan_node.op.inputs) == 1
def test_opt_order():
    """Verify that `Scan` rewrites run before the BLAS rewrites.

    Otherwise the inner ``dot`` would never become a ``Dot22``, making it
    slower and preventing transfer to the GPU.
    """
    x = matrix("x")
    A = matrix("A")

    z, _ = scan(dot, sequences=[], non_sequences=[x, A], n_steps=2)
    f = function([x, A], z, mode="FAST_RUN")

    # The BLAS rewrite must have produced a `Dot22` somewhere in the graph.
    assert any(isinstance(node.op, Dot22) for node in f.maker.fgraph.toposort())

    x_val = np.array([[1.0, 1.0], [2.0, 2.0]], dtype=config.floatX)
    A_val = np.array([[1.0, 1.0], [1.0, 0.0]], dtype=config.floatX)
    expected = np.array([[[2, 1], [4, 2]], [[2, 1], [4, 2]]], dtype=config.floatX)
    utt.assert_allclose(f(x_val, A_val), expected)
import numpy as np
import aesara.tensor as at
from aesara import config, function, grad, shared
from aesara.compile.mode import FAST_RUN
from aesara.scan.views import foldl, foldr
from aesara.scan.views import map as at_map
from aesara.scan.views import reduce as at_reduce
from aesara.tensor.type import scalar, vector
from tests import unittest_tools as utt
from tests.scan.test_basic import clone_optimized_graph, grab_scan_node
def test_reduce():
    """``reduce`` with an addition step computes the sum of the input vector."""
    values = vector("v")
    acc0 = scalar("s")
    result, updates = at_reduce(lambda x, y: x + y, values, acc0)
    f = function([values, acc0], result, updates=updates, allow_input_downcast=True)

    rng = np.random.default_rng(utt.fetch_seed())
    data = rng.uniform(-5.0, 5.0, size=(5,))
    assert abs(np.sum(data) - f(data, 0.0)) < 1e-3
def test_map():
    """``map`` applies the inner function element-wise over the sequence."""
    v = vector("v")
    abs_expr, abs_updates = at_map(
        lambda x: abs(x), v, [], truncate_gradient=-1, go_backwards=False
    )
    f = function([v], abs_expr, updates=abs_updates, allow_input_downcast=True)

    rng = np.random.default_rng(utt.fetch_seed())
    inputs = rng.uniform(-5.0, 5.0, size=(10,))
    utt.assert_allclose(abs(inputs), f(inputs))
def test_reduce_memory_consumption():
    """Check the size of the buffer kept by save-mem for a ``reduce``.

    Only the final accumulator value is needed, so the rewritten `Scan`
    should keep a minimal output buffer (see the comment below for why it
    can be 2 instead of 1).
    """
    x = shared(np.asarray(np.random.uniform(size=(10,)), dtype=config.floatX))
    o, _ = at_reduce(
        lambda v, acc: acc + v,
        x,
        at.constant(np.asarray(0.0, dtype=config.floatX)),
    )
    mode = FAST_RUN
    mode = mode.excluding("inplace")
    f1 = function([], o, mode=mode)
    inputs, outputs = clone_optimized_graph(f1)
    scan_nodes = grab_scan_node(outputs[0])
    assert scan_nodes is not None
    scan_node = scan_nodes[0]
    # Compile a function returning the node's third input -- presumably the
    # rewritten output buffer whose length save-mem controls; confirm against
    # the `Scan` input ordering.
    f1 = function(inputs, scan_node.inputs[2])
    # Originally, the shape would have been 1 due to the SaveMem
    # optimization reducing the size to the number of taps (in this case
    # 1) provided to the inner function. Now, because of the memory-reuse
    # feature in Scan it can be 2 because SaveMem needs to keep a
    # larger buffer to avoid aliasing between the inputs and the outputs.
    if config.scan__allow_output_prealloc:
        assert f1().shape[0] == 2
    else:
        assert f1().shape[0] == 1
    gx = grad(o, x)
    f2 = function([], gx)
    # The gradient of a sum-style reduce w.r.t. each element is 1.
    utt.assert_allclose(f2(), np.ones((10,)))
def test_foldl_memory_consumption():
    """Check the size of the buffer kept by save-mem for a ``foldl``."""
    x = shared(np.asarray(np.random.uniform(size=(10,)), dtype=config.floatX))
    o, _ = foldl(
        lambda v, acc: acc + v,
        x,
        at.constant(np.asarray(0.0, dtype=config.floatX)),
    )

    mode = FAST_RUN.excluding("inplace")
    f0 = function([], o, mode=mode)

    inputs, outputs = clone_optimized_graph(f0)

    scan_nodes = grab_scan_node(outputs[0])
    assert scan_nodes is not None
    scan_node = scan_nodes[0]
    f1 = function(inputs, scan_node.inputs[2])

    # Originally, the shape would have been 1 due to the SaveMem
    # optimization reducing the size to the number of taps (in this case
    # 1) provided to the inner function. Now, because of the memory-reuse
    # feature in Scan it can be 2 because SaveMem needs to keep a
    # larger buffer to avoid aliasing between the inputs and the outputs.
    expected_size = 2 if config.scan__allow_output_prealloc else 1
    assert f1().shape[0] == expected_size

    gx = grad(o, x)
    f2 = function([], gx)
    # The gradient of a sum-style fold w.r.t. each element is 1.
    utt.assert_allclose(f2(), np.ones((10,)))
def test_foldr_memory_consumption():
    """Check the size of the buffer kept by save-mem for a ``foldr``."""
    x = shared(np.asarray(np.random.uniform(size=(10,)), dtype=config.floatX))
    o, _ = foldr(
        lambda v, acc: acc + v,
        x,
        at.constant(np.asarray(0.0, dtype=config.floatX)),
    )

    mode = FAST_RUN.excluding("inplace")
    f1 = function([], o, mode=mode)

    inputs, outputs = clone_optimized_graph(f1)

    scan_nodes = grab_scan_node(outputs[0])
    assert scan_nodes is not None
    scan_node = scan_nodes[0]
    f1 = function(inputs, scan_node.inputs[2])

    # Originally, the shape would have been 1 due to the SaveMem
    # optimization reducing the size to the number of taps (in this case
    # 1) provided to the inner function. Now, because of the memory-reuse
    # feature in Scan it can be 2 because SaveMem needs to keep a
    # larger buffer to avoid aliasing between the inputs and the outputs.
    expected_size = 2 if config.scan__allow_output_prealloc else 1
    assert f1().shape[0] == expected_size

    gx = grad(o, x)
    f2 = function([], gx)
    # The gradient of a sum-style fold w.r.t. each element is 1.
    utt.assert_allclose(f2(), np.ones((10,)))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论