提交 040410f4 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Use direct function import

上级 dcc18636
...@@ -6,7 +6,6 @@ import numpy as np ...@@ -6,7 +6,6 @@ import numpy as np
import pytest import pytest
import scipy import scipy
import pytensor
from pytensor import function, grad from pytensor import function, grad
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -130,7 +129,7 @@ def test_cholesky_grad_indef(): ...@@ -130,7 +129,7 @@ def test_cholesky_grad_indef():
def test_cholesky_infer_shape(): def test_cholesky_infer_shape():
x = matrix() x = matrix()
f_chol = pytensor.function([x], [cholesky(x).shape, cholesky(x, lower=False).shape]) f_chol = function([x], [cholesky(x).shape, cholesky(x, lower=False).shape])
if config.mode != "FAST_COMPILE": if config.mode != "FAST_COMPILE":
topo_chol = f_chol.maker.fgraph.toposort() topo_chol = f_chol.maker.fgraph.toposort()
f_chol.dprint() f_chol.dprint()
...@@ -313,7 +312,7 @@ class TestSolve(utt.InferShapeTester): ...@@ -313,7 +312,7 @@ class TestSolve(utt.InferShapeTester):
b_ndim=len(b_size), b_ndim=len(b_size),
) )
solve_func = pytensor.function([A, b], y) solve_func = function([A, b], y)
X_np = solve_func(A_val.copy(), b_val.copy()) X_np = solve_func(A_val.copy(), b_val.copy())
ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4 ATOL = 1e-8 if config.floatX.endswith("64") else 1e-4
...@@ -444,7 +443,7 @@ class TestSolveTriangular(utt.InferShapeTester): ...@@ -444,7 +443,7 @@ class TestSolveTriangular(utt.InferShapeTester):
b_ndim=len(b_shape), b_ndim=len(b_shape),
) )
f = pytensor.function([A, b], x) f = function([A, b], x)
x_pt = f(A_val, b_val) x_pt = f(A_val, b_val)
x_sp = scipy.linalg.solve_triangular( x_sp = scipy.linalg.solve_triangular(
...@@ -508,8 +507,8 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -508,8 +507,8 @@ class TestCholeskySolve(utt.InferShapeTester):
A = matrix() A = matrix()
b = matrix() b = matrix()
self._compile_and_check( self._compile_and_check(
[A, b], # pytensor.function inputs [A, b], # function inputs
[self.op_class(b_ndim=2)(A, b)], # pytensor.function outputs [self.op_class(b_ndim=2)(A, b)], # function outputs
# A must be square # A must be square
[ [
np.asarray(rng.random((5, 5)), dtype=config.floatX), np.asarray(rng.random((5, 5)), dtype=config.floatX),
...@@ -522,8 +521,8 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -522,8 +521,8 @@ class TestCholeskySolve(utt.InferShapeTester):
A = matrix() A = matrix()
b = vector() b = vector()
self._compile_and_check( self._compile_and_check(
[A, b], # pytensor.function inputs [A, b], # function inputs
[self.op_class(b_ndim=1)(A, b)], # pytensor.function outputs [self.op_class(b_ndim=1)(A, b)], # function outputs
# A must be square # A must be square
[ [
np.asarray(rng.random((5, 5)), dtype=config.floatX), np.asarray(rng.random((5, 5)), dtype=config.floatX),
...@@ -538,10 +537,10 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -538,10 +537,10 @@ class TestCholeskySolve(utt.InferShapeTester):
A = matrix() A = matrix()
b = matrix() b = matrix()
y = self.op_class(lower=True, b_ndim=2)(A, b) y = self.op_class(lower=True, b_ndim=2)(A, b)
cho_solve_lower_func = pytensor.function([A, b], y) cho_solve_lower_func = function([A, b], y)
y = self.op_class(lower=False, b_ndim=2)(A, b) y = self.op_class(lower=False, b_ndim=2)(A, b)
cho_solve_upper_func = pytensor.function([A, b], y) cho_solve_upper_func = function([A, b], y)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
...@@ -603,7 +602,7 @@ def test_lu_decomposition( ...@@ -603,7 +602,7 @@ def test_lu_decomposition(
A = tensor("A", shape=shape, dtype=dtype) A = tensor("A", shape=shape, dtype=dtype)
out = lu(A, permute_l=permute_l, p_indices=p_indices) out = lu(A, permute_l=permute_l, p_indices=p_indices)
f = pytensor.function([A], out) f = function([A], out)
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
x = rng.normal(size=shape).astype(config.floatX) x = rng.normal(size=shape).astype(config.floatX)
...@@ -706,7 +705,7 @@ class TestLUSolve(utt.InferShapeTester): ...@@ -706,7 +705,7 @@ class TestLUSolve(utt.InferShapeTester):
x = self.factor_and_solve(A, b, trans=trans, sum=False) x = self.factor_and_solve(A, b, trans=trans, sum=False)
f = pytensor.function([A, b], x) f = function([A, b], x)
x_pt = f(A_val.copy(), b_val.copy()) x_pt = f(A_val.copy(), b_val.copy())
x_sp = scipy.linalg.lu_solve( x_sp = scipy.linalg.lu_solve(
scipy.linalg.lu_factor(A_val.copy()), b_val.copy(), trans=trans scipy.linalg.lu_factor(A_val.copy()), b_val.copy(), trans=trans
...@@ -744,7 +743,7 @@ def test_lu_factor(): ...@@ -744,7 +743,7 @@ def test_lu_factor():
A = matrix() A = matrix()
A_val = rng.normal(size=(5, 5)).astype(config.floatX) A_val = rng.normal(size=(5, 5)).astype(config.floatX)
f = pytensor.function([A], lu_factor(A)) f = function([A], lu_factor(A))
LU, pt_p_idx = f(A_val) LU, pt_p_idx = f(A_val)
sp_LU, sp_p_idx = scipy.linalg.lu_factor(A_val) sp_LU, sp_p_idx = scipy.linalg.lu_factor(A_val)
...@@ -764,7 +763,7 @@ def test_cho_solve(): ...@@ -764,7 +763,7 @@ def test_cho_solve():
A = matrix() A = matrix()
b = matrix() b = matrix()
y = cho_solve((A, True), b) y = cho_solve((A, True), b)
cho_solve_lower_func = pytensor.function([A, b], y) cho_solve_lower_func = function([A, b], y)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论