提交 c5b96d92 authored 作者: lucianopaz's avatar lucianopaz 提交者: Ricardo Vieira

Don't store fortran objects in ScipyGer tag

上级 0ebc83b3
...@@ -144,20 +144,20 @@ except ImportError as e: ...@@ -144,20 +144,20 @@ except ImportError as e:
# If check_init_y() == True we need to initialize y when beta == 0. # If check_init_y() == True we need to initialize y when beta == 0.
def check_init_y(): def check_init_y():
if check_init_y._result is None: if check_init_y._result is None:
if not have_fblas: if not have_fblas: # pragma: no cover
check_init_y._result = False check_init_y._result = False
else:
y = float("NaN") * np.ones((2,)) y = float("NaN") * np.ones((2,))
x = np.ones((2,)) x = np.ones((2,))
A = np.ones((2, 2)) A = np.ones((2, 2))
gemv = _blas_gemv_fns[y.dtype] gemv = _blas_gemv_fns[y.dtype]
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True) gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
check_init_y._result = np.isnan(y).any() check_init_y._result = np.isnan(y).any()
return check_init_y._result return check_init_y._result
check_init_y._result = None check_init_y._result = None # type: ignore
class Gemv(Op): class Gemv(Op):
......
...@@ -19,17 +19,13 @@ if have_fblas: ...@@ -19,17 +19,13 @@ if have_fblas:
class ScipyGer(Ger): class ScipyGer(Ger):
def prepare_node(self, node, storage_map, compute_map, impl):
if impl == "py":
node.tag.local_ger = _blas_ger_fns[np.dtype(node.inputs[0].type.dtype)]
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
cA, calpha, cx, cy = inputs cA, calpha, cx, cy = inputs
(cZ,) = output_storage (cZ,) = output_storage
# N.B. some versions of scipy (e.g. mine) don't actually work # N.B. some versions of scipy (e.g. mine) don't actually work
# in-place on a, even when I tell it to. # in-place on a, even when I tell it to.
A = cA A = cA
local_ger = node.tag.local_ger local_ger = _blas_ger_fns[cA.dtype]
if A.size == 0: if A.size == 0:
# We don't have to compute anything, A is empty. # We don't have to compute anything, A is empty.
# We need this special case because Numpy considers it # We need this special case because Numpy considers it
......
...@@ -17,7 +17,6 @@ pytensor/scalar/basic.py ...@@ -17,7 +17,6 @@ pytensor/scalar/basic.py
pytensor/sparse/basic.py pytensor/sparse/basic.py
pytensor/sparse/type.py pytensor/sparse/type.py
pytensor/tensor/basic.py pytensor/tensor/basic.py
pytensor/tensor/blas.py
pytensor/tensor/blas_c.py pytensor/tensor/blas_c.py
pytensor/tensor/blas_headers.py pytensor/tensor/blas_headers.py
pytensor/tensor/elemwise.py pytensor/tensor/elemwise.py
...@@ -31,4 +30,4 @@ pytensor/tensor/slinalg.py ...@@ -31,4 +30,4 @@ pytensor/tensor/slinalg.py
pytensor/tensor/subtensor.py pytensor/tensor/subtensor.py
pytensor/tensor/type.py pytensor/tensor/type.py
pytensor/tensor/type_other.py pytensor/tensor/type_other.py
pytensor/tensor/variable.py pytensor/tensor/variable.py
\ No newline at end of file
import pickle
import numpy as np import numpy as np
import pytest import pytest
...@@ -58,6 +60,17 @@ class TestScipyGer(OptimizationTestMixin): ...@@ -58,6 +60,17 @@ class TestScipyGer(OptimizationTestMixin):
self.assertFunctionContains(f, gemm_no_inplace) self.assertFunctionContains(f, gemm_no_inplace)
self.run_f(f) # DebugMode tests correctness self.run_f(f) # DebugMode tests correctness
def test_pickle(self):
out = ScipyGer(destructive=False)(self.A, self.a, self.x, self.y)
f = pytensor.function([self.A, self.a, self.x, self.y], out)
new_f = pickle.loads(pickle.dumps(f))
assert isinstance(new_f.maker.fgraph.toposort()[-1].op, ScipyGer)
assert np.allclose(
f(self.Aval, 1.0, self.xval, self.yval),
new_f(self.Aval, 1.0, self.xval, self.yval),
)
class TestBlasStridesScipy(TestBlasStrides): class TestBlasStridesScipy(TestBlasStrides):
mode = pytensor.compile.get_default_mode() mode = pytensor.compile.get_default_mode()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论