提交 d4b8260c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Eig impls: Keep complex outputs in Blockwise and Numba

上级 09cdd75a
......@@ -81,9 +81,11 @@ def numba_funcify_Eig(op, node, **kwargs):
@numba_basic.numba_njit
def eig(x):
return np.linalg.eig(inputs_cast(x))
w, v = np.linalg.eig(inputs_cast(x))
return w.astype(w_dtype), v.astype(w_dtype)
return eig
cache_version = 1
return eig, cache_version
@register_funcify_default_op_cache_key(Eigh)
......
......@@ -325,7 +325,8 @@ class Eig(Op):
"""
__props__: tuple[str, ...] = ()
gufunc_spec = ("numpy.linalg.eig", 1, 2)
# Can't use numpy directly in Blockwise, because of the dynamic dtype
# gufunc_spec = ("numpy.linalg.eig", 1, 2)
gufunc_signature = "(m,m)->(m),(m,m)"
def make_node(self, x):
......
......@@ -24,7 +24,7 @@ from pytensor.tensor import (
vector,
)
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.nlinalg import MatrixInverse, eig
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.signal import convolve1d
from pytensor.tensor.slinalg import (
......@@ -763,3 +763,22 @@ def test_partial_inplace():
add_supervisor_to_fgraph(fgraph, [In(inp, mutable=True) for inp in fgraph.inputs])
rewrite_graph(fgraph, include=("inplace",))
assert fgraph.outputs[0].owner.op.destroy_map == {1: [1]}
def test_eig_blockwise():
x = tensor("x", shape=(2, 3, 3), dtype="float64")
eigen_values, eigen_vectors = eig(x)
assert eigen_values.dtype == "complex128"
assert eigen_vectors.dtype == "complex128"
fn = function([x], [eigen_values, eigen_vectors])
eigen_values_res, eigen_vectors_res = fn(np.full((2, 3, 3), np.eye(3)))
np.testing.assert_allclose(
eigen_values_res,
np.ones((2, 3), dtype="complex128"),
strict=True,
)
np.testing.assert_allclose(
eigen_vectors_res,
np.full((2, 3, 3), np.eye(3), dtype="complex128"),
strict=True,
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论