Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ed6ca162
提交
ed6ca162
authored
1月 06, 2024
作者:
jessegrabowski
提交者:
Ricardo Vieira
10月 11, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Inplace Blockwise and core versions of Cholesky and Solve Ops.
Co-authored-by:
Ricardo Vieira
<
28983449+ricardov94@users.noreply.github.com
>
上级
b8dbd4ca
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
384 行增加
和
41 行删除
+384
-41
op.py
pytensor/graph/op.py
+6
-0
blockwise.py
pytensor/tensor/blockwise.py
+10
-0
blockwise.py
pytensor/tensor/rewriting/blockwise.py
+80
-2
slinalg.py
pytensor/tensor/slinalg.py
+169
-34
test_blockwise.py
tests/tensor/test_blockwise.py
+114
-3
test_slinalg.py
tests/tensor/test_slinalg.py
+5
-2
没有找到文件。
pytensor/graph/op.py
浏览文件 @
ed6ca162
...
...
@@ -583,6 +583,12 @@ class Op(MetaObject):
)
return
self
.
make_py_thunk
(
node
,
storage_map
,
compute_map
,
no_recycling
)
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
"""Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs`."""
# TODO: Document this in the Create your own Op docs
# By default, do nothing
return
self
def
__str__
(
self
):
return
getattr
(
type
(
self
),
"__name__"
,
super
()
.
__str__
())
...
...
pytensor/tensor/blockwise.py
浏览文件 @
ed6ca162
...
...
@@ -45,6 +45,7 @@ class Blockwise(Op):
signature
:
str
|
None
=
None
,
name
:
str
|
None
=
None
,
gufunc_spec
:
tuple
[
str
,
int
,
int
]
|
None
=
None
,
destroy_map
=
None
,
**
kwargs
,
):
"""
...
...
@@ -79,6 +80,15 @@ class Blockwise(Op):
self
.
inputs_sig
,
self
.
outputs_sig
=
_parse_gufunc_signature
(
signature
)
self
.
gufunc_spec
=
gufunc_spec
self
.
_gufunc
=
None
if
destroy_map
is
not
None
:
self
.
destroy_map
=
destroy_map
if
self
.
destroy_map
!=
core_op
.
destroy_map
:
# Note: Should be fine for destroy_map of Blockwise to be more extensive than that of core_op
# But we are not using that anywhere yet, so this check is fine for now
raise
ValueError
(
f
"Blockwise destroy_map {self.destroy_map} must be the same as that of the core_op {core_op} {core_op.destroy_map}"
)
super
()
.
__init__
(
**
kwargs
)
def
__getstate__
(
self
):
...
...
pytensor/tensor/rewriting/blockwise.py
浏览文件 @
ed6ca162
import
itertools
from
pytensor.compile
import
Supervisor
from
pytensor.compile.mode
import
optdb
from
pytensor.graph
import
Constant
,
node_rewriter
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
out2in
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
in2out
,
out2in
from
pytensor.tensor.basic
import
Alloc
,
ARange
,
alloc
,
shape_padleft
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.math
import
Dot
...
...
@@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node):
# We register this rewrite late, so that other rewrites need only target Blockwise Ops
# We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
optdb
.
register
(
"local_useless_unbatched_blockwise"
,
out2in
(
local_useless_unbatched_blockwise
,
ignore_newtrees
=
True
),
"fast_run"
,
"fast_compile"
,
"blockwise"
,
position
=
49
,
position
=
60
,
)
...
...
@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node):
new_out
=
x
.
reshape
([
*
tuple
(
batched_shape
),
*
tuple
(
core_reshape
)])
copy_stack_trace
(
node
.
outputs
[
0
],
new_out
)
return
[
new_out
]
@node_rewriter
(
tracks
=
[
Blockwise
],
inplace
=
True
)
def
blockwise_inplace
(
fgraph
,
node
):
blockwise_op
=
node
.
op
if
blockwise_op
.
destroy_map
:
# Op already has inplace
return
# Find out valid inputs for inplacing
batch_ndim
=
blockwise_op
.
batch_ndim
(
node
)
out_batch_bcast
=
node
.
outputs
[
0
]
.
type
.
broadcastable
[:
batch_ndim
]
protected_inputs
=
[
f
.
protected
for
f
in
fgraph
.
_features
if
isinstance
(
f
,
Supervisor
)
]
protected_inputs
=
list
(
itertools
.
chain
.
from_iterable
(
protected_inputs
))
protected_inputs
.
extend
(
fgraph
.
outputs
)
allowed_inplace_inputs
=
[
idx
for
idx
,
inp
in
enumerate
(
node
.
inputs
)
if
(
# Constants would need to be recreated every time if inplaced
not
isinstance
(
inp
,
Constant
)
# We can only inplace on inputs that are not being broadcasted
# As those are reused across iterations of Blockwise
and
node
.
inputs
[
idx
]
.
type
.
broadcastable
[:
batch_ndim
]
==
out_batch_bcast
# Inputs that are marked as protected or destroyed can't be inplaced
and
not
fgraph
.
has_destroyers
([
inp
])
and
inp
not
in
protected_inputs
)
]
if
not
allowed_inplace_inputs
:
return
None
inplace_core_op
=
blockwise_op
.
core_op
.
inplace_on_inputs
(
allowed_inplace_inputs
=
allowed_inplace_inputs
)
if
not
inplace_core_op
.
destroy_map
:
return
None
# Check Op is not trying to inplace on non-candidate inputs
for
destroyed_inputs
in
inplace_core_op
.
destroy_map
.
values
():
for
destroyed_input
in
destroyed_inputs
:
if
destroyed_input
not
in
allowed_inplace_inputs
:
raise
ValueError
(
f
"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
)
# Recreate core_op with inplace
inplace_blockwise_op
=
Blockwise
(
core_op
=
inplace_core_op
,
signature
=
blockwise_op
.
signature
,
name
=
blockwise_op
.
name
,
gufunc_spec
=
blockwise_op
.
gufunc_spec
,
destroy_map
=
inplace_core_op
.
destroy_map
,
)
out
=
inplace_blockwise_op
.
make_node
(
*
node
.
inputs
)
.
outputs
copy_stack_trace
(
node
.
outputs
,
out
)
return
out
optdb
.
register
(
"blockwise_inplace"
,
in2out
(
blockwise_inplace
),
"fast_run"
,
"inplace"
,
position
=
50.1
,
)
pytensor/tensor/slinalg.py
浏览文件 @
ed6ca162
...
...
@@ -28,57 +28,68 @@ logger = logging.getLogger(__name__)
class
Cholesky
(
Op
):
"""
Return a triangular matrix square root of positive semi-definite `x`.
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
Parameters
----------
lower : bool, default=True
Whether to return the lower or upper cholesky factor
on_error : ['raise', 'nan']
If on_error is set to 'raise', this Op will raise a
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
If on_error is set to 'nan', it will return a matrix containing
nans instead.
"""
# TODO: inplace
# TODO: for specific dtypes
# TODO: LAPACK wrapper with in-place behavior, for solve also
__props__
=
(
"lower"
,
"
destructive"
,
"on_error
"
)
__props__
=
(
"lower"
,
"
check_finite"
,
"on_error"
,
"overwrite_a
"
)
gufunc_signature
=
"(m,m)->(m,m)"
def
__init__
(
self
,
*
,
lower
=
True
,
check_finite
=
True
,
on_error
=
"raise"
):
def
__init__
(
self
,
*
,
lower
:
bool
=
True
,
check_finite
:
bool
=
True
,
on_error
:
Literal
[
"raise"
,
"nan"
]
=
"raise"
,
overwrite_a
:
bool
=
False
,
):
self
.
lower
=
lower
self
.
destructive
=
False
self
.
check_finite
=
check_finite
if
on_error
not
in
(
"raise"
,
"nan"
):
raise
ValueError
(
'on_error must be one of "raise" or ""nan"'
)
self
.
on_error
=
on_error
self
.
overwrite_a
=
overwrite_a
if
self
.
overwrite_a
:
self
.
destroy_map
=
{
0
:
[
0
]}
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[
shapes
[
0
]]
def
make_node
(
self
,
x
):
x
=
as_tensor_variable
(
x
)
assert
x
.
ndim
==
2
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
if
x
.
type
.
ndim
!=
2
:
raise
TypeError
(
f
"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
)
# Call scipy to find output dtype
dtype
=
scipy
.
linalg
.
cholesky
(
np
.
eye
(
1
,
dtype
=
x
.
type
.
dtype
))
.
dtype
return
Apply
(
self
,
[
x
],
[
tensor
(
shape
=
x
.
type
.
shape
,
dtype
=
dtype
)])
def
perform
(
self
,
node
,
inputs
,
outputs
):
x
=
inputs
[
0
]
z
=
outputs
[
0
]
[
x
]
=
inputs
[
out
]
=
outputs
try
:
z
[
0
]
=
scipy
.
linalg
.
cholesky
(
x
,
lower
=
self
.
lower
,
check_finite
=
self
.
check_finite
)
.
astype
(
x
.
dtype
)
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
if
self
.
overwrite_a
and
x
.
flags
[
"C_CONTIGUOUS"
]:
out
[
0
]
=
scipy
.
linalg
.
cholesky
(
x
.
T
,
lower
=
not
self
.
lower
,
check_finite
=
self
.
check_finite
,
overwrite_a
=
True
,
)
.
T
else
:
out
[
0
]
=
scipy
.
linalg
.
cholesky
(
x
,
lower
=
self
.
lower
,
check_finite
=
self
.
check_finite
,
overwrite_a
=
self
.
overwrite_a
,
)
except
scipy
.
linalg
.
LinAlgError
:
if
self
.
on_error
==
"raise"
:
raise
else
:
z
[
0
]
=
(
np
.
zeros
(
x
.
shape
)
*
np
.
nan
)
.
astype
(
x
.
dtype
)
out
[
0
]
=
np
.
full
(
x
.
shape
,
np
.
nan
,
dtype
=
node
.
outputs
[
0
]
.
type
.
dtype
)
def
L_op
(
self
,
inputs
,
outputs
,
gradients
):
"""
...
...
@@ -131,11 +142,66 @@ class Cholesky(Op):
else
:
return
[
grad
]
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
if
not
allowed_inplace_inputs
:
return
self
new_props
=
self
.
_props_dict
()
# type: ignore
new_props
[
"overwrite_a"
]
=
True
return
type
(
self
)(
**
new_props
)
def
cholesky
(
x
,
lower
=
True
,
on_error
=
"raise"
,
check_finite
=
False
):
return
Blockwise
(
Cholesky
(
lower
=
lower
,
on_error
=
on_error
,
check_finite
=
check_finite
)
)(
x
)
def
cholesky
(
x
:
"TensorLike"
,
lower
:
bool
=
True
,
*
,
check_finite
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
on_error
:
Literal
[
"raise"
,
"nan"
]
=
"raise"
,
):
"""
Return a triangular matrix square root of positive semi-definite `x`.
L = cholesky(X, lower=True) implies dot(L, L.T) == X.
Parameters
----------
x: tensor_like
lower : bool, default=True
Whether to return the lower or upper cholesky factor
check_finite : bool, default=False
Whether to check that the input matrix contains only finite numbers.
overwrite_a: bool, ignored
Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
for consistency with scipy.linalg.cholesky.
on_error : ['raise', 'nan']
If on_error is set to 'raise', this Op will raise a `scipy.linalg.LinAlgError` if the matrix is not positive definite.
If on_error is set to 'nan', it will return a matrix containing nans instead.
Returns
-------
TensorVariable
Lower or upper triangular Cholesky factor of `x`
Example
-------
.. testcode::
import pytensor
import pytensor.tensor as pt
import numpy as np
x = pt.tensor('x', shape=(5, 5), dtype='float64')
L = pt.linalg.cholesky(x)
f = pytensor.function([x], L)
x_value = np.random.normal(size=(5, 5))
x_value = x_value @ x_value.T # Ensures x is positive definite
L_value = f(x_value)
assert np.allclose(L_value @ L_value.T, x_value)
"""
return
Blockwise
(
Cholesky
(
lower
=
lower
,
on_error
=
on_error
))(
x
)
class
SolveBase
(
Op
):
...
...
@@ -145,6 +211,8 @@ class SolveBase(Op):
"lower"
,
"check_finite"
,
"b_ndim"
,
"overwrite_a"
,
"overwrite_b"
,
)
def
__init__
(
...
...
@@ -153,6 +221,8 @@ class SolveBase(Op):
lower
=
False
,
check_finite
=
True
,
b_ndim
,
overwrite_a
=
False
,
overwrite_b
=
False
,
):
self
.
lower
=
lower
self
.
check_finite
=
check_finite
...
...
@@ -162,9 +232,25 @@ class SolveBase(Op):
self
.
gufunc_signature
=
"(m,m),(m)->(m)"
else
:
self
.
gufunc_signature
=
"(m,m),(m,n)->(m,n)"
self
.
overwrite_a
=
overwrite_a
self
.
overwrite_b
=
overwrite_b
destroy_map
=
{}
if
self
.
overwrite_a
and
self
.
overwrite_b
:
# An output destroying two inputs is not yet supported
# destroy_map[0] = [0, 1]
raise
NotImplementedError
(
"It's not yet possible to overwrite_a and overwrite_b simultaneously"
)
elif
self
.
overwrite_a
:
destroy_map
[
0
]
=
[
0
]
elif
self
.
overwrite_b
:
destroy_map
[
0
]
=
[
1
]
self
.
destroy_map
=
destroy_map
def
perform
(
self
,
node
,
inputs
,
outputs
):
pass
raise
NotImplementedError
(
"SolveBase should be subclassed with an perform method"
)
def
make_node
(
self
,
A
,
b
):
A
=
as_tensor_variable
(
A
)
...
...
@@ -235,7 +321,16 @@ def _default_b_ndim(b, b_ndim):
class
CholeskySolve
(
SolveBase
):
__props__
=
(
"lower"
,
"check_finite"
,
"b_ndim"
,
"overwrite_b"
,
)
def
__init__
(
self
,
**
kwargs
):
if
kwargs
.
get
(
"overwrite_a"
,
False
):
raise
ValueError
(
"overwrite_a is not supported for CholeskySolve"
)
kwargs
.
setdefault
(
"lower"
,
True
)
super
()
.
__init__
(
**
kwargs
)
...
...
@@ -245,13 +340,23 @@ class CholeskySolve(SolveBase):
(
C
,
self
.
lower
),
b
,
check_finite
=
self
.
check_finite
,
overwrite_b
=
self
.
overwrite_b
,
)
output_storage
[
0
][
0
]
=
rval
def
L_op
(
self
,
*
args
,
**
kwargs
):
# TODO: Base impl should work, let's try it
raise
NotImplementedError
()
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
if
1
in
allowed_inplace_inputs
:
new_props
=
self
.
_props_dict
()
# type: ignore
new_props
[
"overwrite_b"
]
=
True
return
type
(
self
)(
**
new_props
)
else
:
return
self
def
cho_solve
(
c_and_lower
,
b
,
*
,
check_finite
=
True
,
b_ndim
:
int
|
None
=
None
):
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
...
...
@@ -286,9 +391,12 @@ class SolveTriangular(SolveBase):
"lower"
,
"check_finite"
,
"b_ndim"
,
"overwrite_b"
,
)
def
__init__
(
self
,
*
,
trans
=
0
,
unit_diagonal
=
False
,
**
kwargs
):
if
kwargs
.
get
(
"overwrite_a"
,
False
):
raise
ValueError
(
"overwrite_a is not supported for SolverTriangulare"
)
super
()
.
__init__
(
**
kwargs
)
self
.
trans
=
trans
self
.
unit_diagonal
=
unit_diagonal
...
...
@@ -302,6 +410,7 @@ class SolveTriangular(SolveBase):
trans
=
self
.
trans
,
unit_diagonal
=
self
.
unit_diagonal
,
check_finite
=
self
.
check_finite
,
overwrite_b
=
self
.
overwrite_b
,
)
def
L_op
(
self
,
inputs
,
outputs
,
output_gradients
):
...
...
@@ -314,6 +423,14 @@ class SolveTriangular(SolveBase):
return
res
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
if
1
in
allowed_inplace_inputs
:
new_props
=
self
.
_props_dict
()
# type: ignore
new_props
[
"overwrite_b"
]
=
True
return
type
(
self
)(
**
new_props
)
else
:
return
self
def
solve_triangular
(
a
:
TensorVariable
,
...
...
@@ -374,6 +491,8 @@ class Solve(SolveBase):
"lower"
,
"check_finite"
,
"b_ndim"
,
"overwrite_a"
,
"overwrite_b"
,
)
def
__init__
(
self
,
*
,
assume_a
=
"gen"
,
**
kwargs
):
...
...
@@ -391,8 +510,24 @@ class Solve(SolveBase):
lower
=
self
.
lower
,
check_finite
=
self
.
check_finite
,
assume_a
=
self
.
assume_a
,
overwrite_a
=
self
.
overwrite_a
,
overwrite_b
=
self
.
overwrite_b
,
)
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
if
not
allowed_inplace_inputs
:
return
self
new_props
=
self
.
_props_dict
()
# type: ignore
# PyTensor doesn't allow an output to destroy two inputs yet
# new_props["overwrite_a"] = 0 in allowed_inplace_inputs
# new_props["overwrite_b"] = 1 in allowed_inplace_inputs
if
1
in
allowed_inplace_inputs
:
# Give preference to overwrite_b
new_props
[
"overwrite_b"
]
=
True
else
:
# allowed inputs == [0]
new_props
[
"overwrite_a"
]
=
True
return
type
(
self
)(
**
new_props
)
def
solve
(
a
,
...
...
tests/tensor/test_blockwise.py
浏览文件 @
ed6ca162
...
...
@@ -3,10 +3,11 @@ from itertools import product
import
numpy
as
np
import
pytest
import
scipy.linalg
import
pytensor
from
pytensor
import
config
,
function
from
pytensor.compile
import
get_mode
from
pytensor
import
In
,
config
,
function
from
pytensor.compile
import
get_
default_mode
,
get_
mode
from
pytensor.gradient
import
grad
from
pytensor.graph
import
Apply
,
Op
from
pytensor.graph.replace
import
vectorize_node
...
...
@@ -15,7 +16,15 @@ from pytensor.tensor import diagonal, log, tensor
from
pytensor.tensor.blockwise
import
Blockwise
,
vectorize_node_fallback
from
pytensor.tensor.nlinalg
import
MatrixInverse
from
pytensor.tensor.rewriting.blas
import
specialize_matmul_to_batched_dot
from
pytensor.tensor.slinalg
import
Cholesky
,
Solve
,
cholesky
,
solve_triangular
from
pytensor.tensor.slinalg
import
(
Cholesky
,
Solve
,
SolveBase
,
cho_solve
,
cholesky
,
solve
,
solve_triangular
,
)
from
pytensor.tensor.utils
import
_parse_gufunc_signature
...
...
@@ -398,3 +407,105 @@ def test_cop_with_params():
with
pytest
.
raises
(
AssertionError
):
fn
(
np
.
zeros
((
5
,
3
,
2
))
-
1
)
@pytest.mark.skipif
(
config
.
mode
==
"FAST_COMPILE"
,
reason
=
"inplace rewrites disabled when mode is FAST_COMPILE"
,
)
class
TestInplace
:
@pytest.mark.parametrize
(
"is_batched"
,
(
False
,
True
))
def
test_cholesky
(
self
,
is_batched
):
X
=
tensor
(
"X"
,
shape
=
(
5
,
None
,
None
)
if
is_batched
else
(
None
,
None
))
L
=
cholesky
(
X
,
lower
=
True
)
f
=
function
([
In
(
X
,
mutable
=
True
)],
L
)
assert
not
L
.
owner
.
op
.
core_op
.
destroy_map
if
is_batched
:
[
cholesky_op
]
=
[
node
.
op
.
core_op
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Blockwise
)
and
isinstance
(
node
.
op
.
core_op
,
Cholesky
)
]
else
:
[
cholesky_op
]
=
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Cholesky
)
]
assert
cholesky_op
.
destroy_map
==
{
0
:
[
0
]}
rng
=
np
.
random
.
default_rng
(
441
+
is_batched
)
X_val
=
rng
.
normal
(
size
=
(
10
,
10
))
.
astype
(
config
.
floatX
)
X_val_in
=
X_val
@
X_val
.
T
if
is_batched
:
X_val_in
=
np
.
broadcast_to
(
X_val_in
,
(
5
,
*
X_val_in
.
shape
))
.
copy
()
X_val_in_copy
=
X_val_in
.
copy
()
f
(
X_val_in
)
np
.
testing
.
assert_allclose
(
X_val_in
,
np
.
linalg
.
cholesky
(
X_val_in_copy
),
atol
=
1e-5
if
config
.
floatX
==
"float32"
else
0
,
)
@pytest.mark.parametrize
(
"batched_A"
,
(
False
,
True
))
@pytest.mark.parametrize
(
"batched_b"
,
(
False
,
True
))
@pytest.mark.parametrize
(
"solve_fn"
,
(
solve
,
solve_triangular
,
cho_solve
))
def
test_solve
(
self
,
solve_fn
,
batched_A
,
batched_b
):
A
=
tensor
(
"A"
,
shape
=
(
5
,
3
,
3
)
if
batched_A
else
(
3
,
3
))
b
=
tensor
(
"b"
,
shape
=
(
5
,
3
)
if
batched_b
else
(
3
,))
if
solve_fn
==
cho_solve
:
# Special signature for cho_solve
x
=
solve_fn
((
A
,
True
),
b
,
b_ndim
=
1
)
else
:
x
=
solve_fn
(
A
,
b
,
b_ndim
=
1
)
mode
=
get_default_mode
()
.
excluding
(
"batched_vector_b_solve_to_matrix_b_solve"
)
fn
=
function
([
In
(
A
,
mutable
=
True
),
In
(
b
,
mutable
=
True
)],
x
,
mode
=
mode
)
op
=
fn
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
if
batched_A
or
batched_b
:
assert
isinstance
(
op
,
Blockwise
)
and
isinstance
(
op
.
core_op
,
SolveBase
)
if
batched_A
and
not
batched_b
:
if
solve_fn
==
solve
:
assert
op
.
destroy_map
==
{
0
:
[
0
]}
else
:
# SolveTriangular does not destroy A
assert
op
.
destroy_map
==
{}
else
:
assert
op
.
destroy_map
==
{
0
:
[
1
]}
else
:
assert
isinstance
(
op
,
SolveBase
)
assert
op
.
destroy_map
==
{
0
:
[
1
]}
# We test with an F_CONTIGUOUS (core) A as only that will be destroyed by scipy
rng
=
np
.
random
.
default_rng
(
487
+
batched_A
+
2
*
batched_b
+
sum
(
map
(
ord
,
solve_fn
.
__name__
))
)
A_val
=
np
.
swapaxes
(
rng
.
normal
(
size
=
A
.
type
.
shape
)
.
astype
(
A
.
type
.
dtype
),
-
1
,
-
2
)
b_val
=
np
.
random
.
normal
(
size
=
b
.
type
.
shape
)
.
astype
(
b
.
type
.
dtype
)
A_val_copy
=
A_val
.
copy
()
b_val_copy
=
b_val
.
copy
()
out
=
fn
(
A_val
,
b_val
)
if
solve_fn
==
cho_solve
:
def
core_scipy_fn
(
A
,
b
):
return
scipy
.
linalg
.
cho_solve
((
A
,
True
),
b
)
else
:
core_scipy_fn
=
getattr
(
scipy
.
linalg
,
solve_fn
.
__name__
)
expected_out
=
np
.
vectorize
(
core_scipy_fn
,
signature
=
"(m,m),(m)->(m)"
)(
A_val_copy
,
b_val_copy
)
np
.
testing
.
assert_allclose
(
out
,
expected_out
,
atol
=
1e-6
if
config
.
floatX
==
"float32"
else
0
)
# Confirm input was destroyed
assert
(
A_val
==
A_val_copy
)
.
all
()
==
(
op
.
destroy_map
.
get
(
0
,
None
)
!=
[
0
])
assert
(
b_val
==
b_val_copy
)
.
all
()
==
(
op
.
destroy_map
.
get
(
0
,
None
)
!=
[
1
])
tests/tensor/test_slinalg.py
浏览文件 @
ed6ca162
...
...
@@ -197,7 +197,10 @@ class TestSolveBase(utt.InferShapeTester):
A
=
matrix
()
b
=
matrix
()
y
=
SolveBase
(
b_ndim
=
2
)(
A
,
b
)
assert
y
.
__repr__
()
==
"SolveBase{lower=False, check_finite=True, b_ndim=2}.0"
assert
(
y
.
__repr__
()
==
"SolveBase{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
)
class
TestSolve
(
utt
.
InferShapeTester
):
...
...
@@ -361,7 +364,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def
test_repr
(
self
):
assert
(
repr
(
CholeskySolve
(
lower
=
True
,
b_ndim
=
1
))
==
"CholeskySolve(lower=True,check_finite=True,b_ndim=1)"
==
"CholeskySolve(lower=True,check_finite=True,b_ndim=1
,overwrite_b=False
)"
)
def
test_infer_shape
(
self
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论