提交 e2202bc7 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Remove use of `aesara.tensor.nnet` in other tests

上级 4c685afb
......@@ -5,10 +5,9 @@ from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.tensor import elemwise as at_elemwise
from aesara.tensor import nnet as at_nnet
from aesara.tensor.math import SoftmaxGrad
from aesara.tensor.math import all as at_all
from aesara.tensor.math import prod
from aesara.tensor.math import log_softmax, prod, softmax
from aesara.tensor.math import sum as at_sum
from aesara.tensor.type import matrix, tensor, vector
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -76,7 +75,7 @@ def test_jax_CAReduce():
def test_softmax(axis):
x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = at_nnet.softmax(x, axis=axis)
out = softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......@@ -85,7 +84,7 @@ def test_softmax(axis):
def test_logsoftmax(axis):
x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = at_nnet.logsoftmax(x, axis=axis)
out = log_softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......
......@@ -7,7 +7,6 @@ from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.scalar.basic import Composite
from aesara.tensor import nnet as at_nnet
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import all as at_all
from aesara.tensor.math import (
......@@ -128,10 +127,6 @@ def test_nnet():
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = at_nnet.ultra_fast_sigmoid(x)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = softplus(x)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......
......@@ -444,7 +444,7 @@ def test_grad_inrange():
def test_grad_abs():
a = fscalar("a")
b = aesara.tensor.nnet.relu(a)
b = 0.5 * (a + aesara.tensor.abs(a))
c = aesara.grad(b, a)
f = aesara.function([a], c, mode=Mode(optimizer=None))
# Currently Aesara return 0.5, but it isn't sure it won't change
......
......@@ -43,7 +43,6 @@ from aesara.tensor.math import all as at_all
from aesara.tensor.math import dot, exp, mean, sigmoid
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh
from aesara.tensor.nnet import categorical_crossentropy
from aesara.tensor.random import normal
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import Shape_i, reshape, specify_shape
......@@ -58,7 +57,6 @@ from aesara.tensor.type import (
fscalar,
ftensor3,
fvector,
imatrix,
iscalar,
ivector,
lscalar,
......@@ -3810,36 +3808,6 @@ class TestExamples:
# TODO FIXME: What is this testing? At least assert something.
def test_grad_two_scans(self):
# data input & output
x = tensor3("x")
t = imatrix("t")
# forward pass
W = shared(
np.random.default_rng(utt.fetch_seed()).random((2, 2)).astype("float32"),
name="W",
borrow=True,
)
def forward_scanner(x_t):
a2_t = dot(x_t, W)
y_t = softmax_graph(a2_t)
return y_t
y, _ = scan(fn=forward_scanner, sequences=x, outputs_info=[None])
# loss function
def error_scanner(y_t, t_t):
return mean(categorical_crossentropy(y_t, t_t))
L, _ = scan(fn=error_scanner, sequences=[y, t], outputs_info=[None])
L = mean(L)
# backward pass
grad(L, [W])
def _grad_mout_helper(self, n_iters, mode):
rng = np.random.default_rng(utt.fetch_seed())
n_hid = 3
......
差异被折叠。
......@@ -25,10 +25,9 @@ from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.tensor.math import argmax, dot
from aesara.tensor.math import max as at_max
from aesara.tensor.nnet import conv, conv2d
from aesara.tensor.shape import unbroadcast
from aesara.tensor.signal.pool import Pool
from aesara.tensor.type import TensorType, matrix, vector
from aesara.tensor.type import matrix, vector
from tests import unittest_tools as utt
......@@ -302,62 +301,6 @@ class TestRopLop(RopLopChecker):
v2 = scan_f()
assert np.allclose(v1, v2), f"Rop mismatch: {v1} {v2}"
def test_conv(self):
for conv_op in [conv.conv2d, conv2d]:
for border_mode in ["valid", "full"]:
image_shape = (2, 2, 4, 5)
filter_shape = (2, 2, 2, 3)
image_dim = len(image_shape)
filter_dim = len(filter_shape)
input = TensorType(aesara.config.floatX, [False] * image_dim)(
name="input"
)
filters = TensorType(aesara.config.floatX, [False] * filter_dim)(
name="filter"
)
ev_input = TensorType(aesara.config.floatX, [False] * image_dim)(
name="ev_input"
)
ev_filters = TensorType(aesara.config.floatX, [False] * filter_dim)(
name="ev_filters"
)
def sym_conv2d(input, filters):
return conv_op(input, filters, border_mode=border_mode)
output = sym_conv2d(input, filters).flatten()
yv = Rop(output, [input, filters], [ev_input, ev_filters])
mode = None
if aesara.config.mode == "FAST_COMPILE":
mode = "FAST_RUN"
rop_f = function(
[input, filters, ev_input, ev_filters],
yv,
on_unused_input="ignore",
mode=mode,
)
sy, _ = aesara.scan(
lambda i, y, x1, x2, v1, v2: (grad(y[i], x1) * v1).sum()
+ (grad(y[i], x2) * v2).sum(),
sequences=at.arange(output.shape[0]),
non_sequences=[output, input, filters, ev_input, ev_filters],
mode=mode,
)
scan_f = function(
[input, filters, ev_input, ev_filters],
sy,
on_unused_input="ignore",
mode=mode,
)
dtype = aesara.config.floatX
image_data = np.random.random(image_shape).astype(dtype)
filter_data = np.random.random(filter_shape).astype(dtype)
ev_image_data = np.random.random(image_shape).astype(dtype)
ev_filter_data = np.random.random(filter_shape).astype(dtype)
v1 = rop_f(image_data, filter_data, ev_image_data, ev_filter_data)
v2 = scan_f(image_data, filter_data, ev_image_data, ev_filter_data)
assert np.allclose(v1, v2), f"Rop mismatch: {v1} {v2}"
def test_join(self):
tv = np.asarray(self.rng.uniform(size=(10,)), aesara.config.floatX)
t = aesara.shared(tv)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论