提交 e0a2a865 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove predefined inplace Elemwise Ops and redundant tests

上级 42e8490c
...@@ -84,13 +84,13 @@ jobs: ...@@ -84,13 +84,13 @@ jobs:
install-mlx: [0] install-mlx: [0]
install-xarray: [0] install-xarray: [0]
part: part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor" - "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor"
- "tests/scan" - "tests/scan"
- "tests/tensor --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/conv --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py" - "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting"
- "tests/tensor/rewriting" - "tests/tensor/test_basic.py tests/tensor/test_elemwise.py"
- "tests/tensor/test_math.py" - "tests/tensor/test_math.py"
- "tests/tensor/test_basic.py tests/tensor/test_inplace.py tests/tensor/conv" - "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/conv"
- "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py" - "tests/tensor/rewriting"
exclude: exclude:
- python-version: "3.11" - python-version: "3.11"
fast-compile: 1 fast-compile: 1
...@@ -167,7 +167,7 @@ jobs: ...@@ -167,7 +167,7 @@ jobs:
install-numba: 0 install-numba: 0
install-jax: 0 install-jax: 0
install-torch: 0 install-torch: 0
part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py" part: "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"
steps: steps:
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
......
...@@ -20,7 +20,7 @@ from pytensor.misc.frozendict import frozendict ...@@ -20,7 +20,7 @@ from pytensor.misc.frozendict import frozendict
from pytensor.printing import Printer, pprint from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import identity as scalar_identity from pytensor.scalar.basic import identity as scalar_identity
from pytensor.scalar.basic import int64, transfer_type, upcast from pytensor.scalar.basic import int64, upcast
from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
...@@ -1634,17 +1634,12 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None): ...@@ -1634,17 +1634,12 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
symbolname = symbolname or symbol.__name__ symbolname = symbolname or symbol.__name__
if symbolname.endswith("_inplace"): if symbolname.endswith("_inplace"):
base_symbol_name = symbolname[: -len("_inplace")] raise ValueError(
scalar_op = getattr(scalar, base_symbol_name) "Creation of automatic inplace elemwise operations deprecated"
inplace_scalar_op = scalar_op.__class__(transfer_type(0))
rval = Elemwise(
inplace_scalar_op,
{0: 0},
nfunc_spec=(nfunc and (nfunc, nin, nout)),
) )
else:
scalar_op = getattr(scalar, symbolname) scalar_op = getattr(scalar, symbolname)
rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout))) rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout)))
if getattr(symbol, "__doc__"): if getattr(symbol, "__doc__"):
rval.__doc__ = symbol.__doc__ rval.__doc__ = symbol.__doc__
......
from pytensor import printing
from pytensor.printing import pprint
from pytensor.tensor.elemwise import scalar_elemwise
@scalar_elemwise
def lt_inplace(a, b):
"""a < b (inplace on a)"""
@scalar_elemwise
def gt_inplace(a, b):
"""a > b (inplace on a)"""
@scalar_elemwise
def le_inplace(a, b):
"""a <= b (inplace on a)"""
@scalar_elemwise
def ge_inplace(a, b):
"""a >= b (inplace on a)"""
@scalar_elemwise
def eq_inplace(a, b):
"""a == b (inplace on a)"""
@scalar_elemwise
def neq_inplace(a, b):
"""a != b (inplace on a)"""
@scalar_elemwise
def and__inplace(a, b):
"""bitwise a & b (inplace on a)"""
@scalar_elemwise
def or__inplace(a, b):
"""bitwise a | b (inplace on a)"""
@scalar_elemwise
def xor_inplace(a, b):
"""bitwise a ^ b (inplace on a)"""
@scalar_elemwise
def invert_inplace(a):
"""bitwise ~a (inplace on a)"""
@scalar_elemwise
def abs_inplace(a):
"""|`a`| (inplace on `a`)"""
@scalar_elemwise
def exp_inplace(a):
"""e^`a` (inplace on `a`)"""
@scalar_elemwise
def exp2_inplace(a):
"""2^`a` (inplace on `a`)"""
@scalar_elemwise
def expm1_inplace(a):
"""e^`a` - 1 (inplace on `a`)"""
@scalar_elemwise
def neg_inplace(a):
"""-a (inplace on a)"""
@scalar_elemwise
def reciprocal_inplace(a):
"""1.0/a (inplace on a)"""
@scalar_elemwise
def log_inplace(a):
"""base e logarithm of a (inplace on a)"""
@scalar_elemwise
def log1p_inplace(a):
"""log(1+a)"""
@scalar_elemwise
def log2_inplace(a):
"""base 2 logarithm of a (inplace on a)"""
@scalar_elemwise
def log10_inplace(a):
"""base 10 logarithm of a (inplace on a)"""
@scalar_elemwise
def sign_inplace(a):
"""sign of `a` (inplace on `a`)"""
@scalar_elemwise
def ceil_inplace(a):
"""ceil of `a` (inplace on `a`)"""
@scalar_elemwise
def floor_inplace(a):
"""floor of `a` (inplace on `a`)"""
@scalar_elemwise
def trunc_inplace(a):
"""trunc of `a` (inplace on `a`)"""
@scalar_elemwise
def round_half_to_even_inplace(a):
"""round_half_to_even_inplace(a) (inplace on `a`)"""
@scalar_elemwise
def round_half_away_from_zero_inplace(a):
"""round_half_away_from_zero_inplace(a) (inplace on `a`)"""
@scalar_elemwise
def sqr_inplace(a):
"""square of `a` (inplace on `a`)"""
@scalar_elemwise
def sqrt_inplace(a):
"""square root of `a` (inplace on `a`)"""
@scalar_elemwise
def deg2rad_inplace(a):
"""convert degree `a` to radian(inplace on `a`)"""
@scalar_elemwise
def rad2deg_inplace(a):
"""convert radian `a` to degree(inplace on `a`)"""
@scalar_elemwise
def cos_inplace(a):
"""cosine of `a` (inplace on `a`)"""
@scalar_elemwise
def arccos_inplace(a):
"""arccosine of `a` (inplace on `a`)"""
@scalar_elemwise
def sin_inplace(a):
"""sine of `a` (inplace on `a`)"""
@scalar_elemwise
def arcsin_inplace(a):
"""arcsine of `a` (inplace on `a`)"""
@scalar_elemwise
def tan_inplace(a):
"""tangent of `a` (inplace on `a`)"""
@scalar_elemwise
def arctan_inplace(a):
"""arctangent of `a` (inplace on `a`)"""
@scalar_elemwise
def arctan2_inplace(a, b):
"""arctangent of `a` / `b` (inplace on `a`)"""
@scalar_elemwise
def cosh_inplace(a):
"""hyperbolic cosine of `a` (inplace on `a`)"""
@scalar_elemwise
def arccosh_inplace(a):
"""hyperbolic arc cosine of `a` (inplace on `a`)"""
@scalar_elemwise
def sinh_inplace(a):
"""hyperbolic sine of `a` (inplace on `a`)"""
@scalar_elemwise
def arcsinh_inplace(a):
"""hyperbolic arc sine of `a` (inplace on `a`)"""
@scalar_elemwise
def tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)"""
@scalar_elemwise
def arctanh_inplace(a):
"""hyperbolic arc tangent of `a` (inplace on `a`)"""
@scalar_elemwise
def erf_inplace(a):
"""error function"""
@scalar_elemwise
def erfc_inplace(a):
"""complementary error function"""
@scalar_elemwise
def erfcx_inplace(a):
"""scaled complementary error function"""
@scalar_elemwise
def owens_t_inplace(h, a):
"""owens t function"""
@scalar_elemwise
def gamma_inplace(a):
"""gamma function"""
@scalar_elemwise
def gammaln_inplace(a):
"""log gamma function"""
@scalar_elemwise
def psi_inplace(a):
"""derivative of log gamma function"""
@scalar_elemwise
def tri_gamma_inplace(a):
"""second derivative of the log gamma function"""
@scalar_elemwise
def gammainc_inplace(k, x):
"""regularized lower gamma function (P)"""
@scalar_elemwise
def gammaincc_inplace(k, x):
"""regularized upper gamma function (Q)"""
@scalar_elemwise
def gammau_inplace(k, x):
"""upper incomplete gamma function"""
@scalar_elemwise
def gammal_inplace(k, x):
"""lower incomplete gamma function"""
@scalar_elemwise
def gammaincinv_inplace(k, x):
"""Inverse to the regularized lower incomplete gamma function"""
@scalar_elemwise
def gammainccinv_inplace(k, x):
"""Inverse of the regularized upper incomplete gamma function"""
@scalar_elemwise
def j0_inplace(x):
"""Bessel function of the first kind of order 0."""
@scalar_elemwise
def j1_inplace(x):
"""Bessel function of the first kind of order 1."""
@scalar_elemwise
def jv_inplace(v, x):
"""Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def i0_inplace(x):
"""Modified Bessel function of the first kind of order 0."""
@scalar_elemwise
def i1_inplace(x):
"""Modified Bessel function of the first kind of order 1."""
@scalar_elemwise
def iv_inplace(v, x):
"""Modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def ive_inplace(v, x):
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def sigmoid_inplace(x):
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
@scalar_elemwise
def softplus_inplace(x):
"""Compute log(1 + exp(x)), also known as softplus or log1pexp"""
@scalar_elemwise
def log1mexp_inplace(x):
"""Compute log(1 - exp(x)), also known as log1mexp"""
@scalar_elemwise
def betainc_inplace(a, b, x):
"""Regularized incomplete beta function"""
@scalar_elemwise
def betaincinv_inplace(a, b, x):
"""Inverse of the regularized incomplete beta function"""
@scalar_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
fill_inplace = second_inplace
pprint.assign(fill_inplace, printing.FunctionPrinter(["fill="]))
@scalar_elemwise
def maximum_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@scalar_elemwise
def minimum_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@scalar_elemwise
def add_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@scalar_elemwise
def sub_inplace(a, b):
"""elementwise subtraction (inplace on `a`)"""
@scalar_elemwise
def mul_inplace(a, b):
"""elementwise multiplication (inplace on `a`)"""
@scalar_elemwise
def true_div_inplace(a, b):
"""elementwise division (inplace on `a`)"""
@scalar_elemwise
def int_div_inplace(a, b):
"""elementwise division (inplace on `a`)"""
@scalar_elemwise
def mod_inplace(a, b):
"""elementwise modulo (inplace on `a`)"""
@scalar_elemwise
def pow_inplace(a, b):
"""elementwise power (inplace on `a`)"""
@scalar_elemwise
def conj_inplace(a):
"""elementwise conjugate (inplace on `a`)"""
@scalar_elemwise
def hyp2f1_inplace(a, b, c, z):
"""gaussian hypergeometric function"""
pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))
pprint.assign(neg_inplace, printing.OperatorPrinter("-=", 0, "either"))
pprint.assign(true_div_inplace, printing.OperatorPrinter("/=", -1, "left"))
pprint.assign(int_div_inplace, printing.OperatorPrinter("//=", -1, "left"))
pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
def transpose_inplace(x, **kwargs):
"Perform a transpose on a tensor without copying the underlying storage"
dims = list(range(x.ndim - 1, -1, -1))
return x.dimshuffle(dims)
...@@ -6,13 +6,13 @@ import scipy.special ...@@ -6,13 +6,13 @@ import scipy.special
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
import pytensor.tensor.inplace as pti
import pytensor.tensor.math as ptm import pytensor.tensor.math as ptm
from pytensor import config, function from pytensor import config, function
from pytensor.compile import get_mode from pytensor.compile import get_mode
from pytensor.compile.ops import deep_copy_op from pytensor.compile.ops import deep_copy_op
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.scalar import Composite, float64 from pytensor.scalar import Composite, float64
from pytensor.scalar import add as scalar_add
from pytensor.tensor import blas, tensor from pytensor.tensor import blas, tensor
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
...@@ -30,6 +30,8 @@ from tests.tensor.test_elemwise import ( ...@@ -30,6 +30,8 @@ from tests.tensor.test_elemwise import (
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
add_inplace = Elemwise(scalar_add, {0: 0})
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inputs, input_vals, output_fn", "inputs, input_vals, output_fn",
...@@ -80,7 +82,7 @@ rng = np.random.default_rng(42849) ...@@ -80,7 +82,7 @@ rng = np.random.default_rng(42849)
np.array(1.0, dtype=config.floatX), np.array(1.0, dtype=config.floatX),
np.array(1.0, dtype=config.floatX), np.array(1.0, dtype=config.floatX),
], ],
lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)), lambda x, y: add_inplace(deep_copy_op(x), deep_copy_op(y)),
), ),
( (
[pt.vector(), pt.vector()], [pt.vector(), pt.vector()],
...@@ -88,7 +90,7 @@ rng = np.random.default_rng(42849) ...@@ -88,7 +90,7 @@ rng = np.random.default_rng(42849)
rng.standard_normal(100).astype(config.floatX), rng.standard_normal(100).astype(config.floatX),
rng.standard_normal(100).astype(config.floatX), rng.standard_normal(100).astype(config.floatX),
], ],
lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)), lambda x, y: add_inplace(deep_copy_op(x), deep_copy_op(y)),
), ),
( (
[pt.vector(), pt.vector()], [pt.vector(), pt.vector()],
......
...@@ -31,7 +31,6 @@ from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph ...@@ -31,7 +31,6 @@ from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.graph.traversal import ancestors from pytensor.graph.traversal import ancestors
from pytensor.printing import debugprint from pytensor.printing import debugprint
from pytensor.scalar import PolyGamma, Psi, TriGamma from pytensor.scalar import PolyGamma, Psi, TriGamma
from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, constant, join, second, switch from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv from pytensor.tensor.blas_c import CGemv
...@@ -1134,15 +1133,15 @@ def test_log1p(): ...@@ -1134,15 +1133,15 @@ def test_log1p():
f = function([x], log(1 + (x)), mode=m) f = function([x], log(1 + (x)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] == [log1p] assert [node.op for node in f.maker.fgraph.toposort()] == [log1p]
f = function([x], log(1 + (-x)), mode=m) f = function([x], log(1 + (-x)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] == [ assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] == [
neg, ps.neg,
inplace.log1p_inplace, ps.log1p,
] ]
f = function([x], -log(1 + (-x)), mode=m) f = function([x], -log(1 + (-x)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] == [ assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] == [
neg, ps.neg,
inplace.log1p_inplace, ps.log1p,
inplace.neg_inplace, ps.neg,
] ]
# check trickier cases (and use different dtype) # check trickier cases (and use different dtype)
...@@ -4035,27 +4034,27 @@ class TestSigmoidRewrites: ...@@ -4035,27 +4034,27 @@ class TestSigmoidRewrites:
# todo: solve issue #4589 first # todo: solve issue #4589 first
# assert check_stack_trace( # assert check_stack_trace(
# f, ops_to_check=[sigmoid, neg_inplace]) # f, ops_to_check=[sigmoid, neg_inplace])
assert [node.op for node in f.maker.fgraph.toposort()] == [ assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] == [
sigmoid, ps.sigmoid,
inplace.neg_inplace, ps.neg,
] ]
f(data) f(data)
f = pytensor.function([x], pt.fill(x, -1.0) / (1 - exp(-x)), mode=m) f = pytensor.function([x], pt.fill(x, -1.0) / (1 - exp(-x)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] != [ assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
sigmoid, ps.sigmoid,
inplace.neg_inplace, ps.neg,
] ]
f(data) f(data)
f = pytensor.function([x], pt.fill(x, -1.0) / (2 + exp(-x)), mode=m) f = pytensor.function([x], pt.fill(x, -1.0) / (2 + exp(-x)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] != [ assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
sigmoid, ps.sigmoid,
inplace.neg_inplace, ps.neg,
] ]
f(data) f(data)
f = pytensor.function([x], pt.fill(x, -1.1) / (1 + exp(-x)), mode=m) f = pytensor.function([x], pt.fill(x, -1.1) / (1 + exp(-x)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] != [ assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
sigmoid, ps.sigmoid,
inplace.neg_inplace, ps.neg,
] ]
f(data) f(data)
...@@ -4077,10 +4076,10 @@ class TestSigmoidRewrites: ...@@ -4077,10 +4076,10 @@ class TestSigmoidRewrites:
(pt.fill(x, -1.1) * exp(x)) / ((1 + exp(x)) * (1 + exp(-x))), (pt.fill(x, -1.1) * exp(x)) / ((1 + exp(x)) * (1 + exp(-x))),
mode=m, mode=m,
) )
assert [node.op for node in f.maker.fgraph.toposort()] != [ assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
sigmoid, ps.sigmoid,
mul, ps.mul,
inplace.neg_inplace, ps.neg,
] ]
f(data) f(data)
f = pytensor.function( f = pytensor.function(
...@@ -4088,10 +4087,10 @@ class TestSigmoidRewrites: ...@@ -4088,10 +4087,10 @@ class TestSigmoidRewrites:
(pt.fill(x, -1.0) * exp(x)) / ((2 + exp(x)) * (1 + exp(-x))), (pt.fill(x, -1.0) * exp(x)) / ((2 + exp(x)) * (1 + exp(-x))),
mode=m, mode=m,
) )
assert [node.op for node in f.maker.fgraph.toposort()] != [ assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
sigmoid, ps.sigmoid,
mul, ps.mul,
inplace.neg_inplace, ps.neg,
] ]
f(data) f(data)
f = pytensor.function( f = pytensor.function(
...@@ -4099,10 +4098,10 @@ class TestSigmoidRewrites: ...@@ -4099,10 +4098,10 @@ class TestSigmoidRewrites:
(pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (2 + exp(-x))), (pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (2 + exp(-x))),
mode=m, mode=m,
) )
assert [node.op for node in f.maker.fgraph.toposort()] != [ assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
sigmoid, ps.sigmoid,
mul, ps.mul,
inplace.neg_inplace, ps.neg,
] ]
f(data) f(data)
f = pytensor.function( f = pytensor.function(
...@@ -4110,10 +4109,10 @@ class TestSigmoidRewrites: ...@@ -4110,10 +4109,10 @@ class TestSigmoidRewrites:
(pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (1 + exp(x))), (pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (1 + exp(x))),
mode=m, mode=m,
) )
assert [node.op for node in f.maker.fgraph.toposort()] != [ assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
sigmoid, ps.sigmoid,
mul, ps.mul,
inplace.neg_inplace, ps.neg,
] ]
f(data) f(data)
f = pytensor.function( f = pytensor.function(
...@@ -4121,10 +4120,10 @@ class TestSigmoidRewrites: ...@@ -4121,10 +4120,10 @@ class TestSigmoidRewrites:
(pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (2 + exp(-x))), (pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (2 + exp(-x))),
mode=m, mode=m,
) )
assert [node.op for node in f.maker.fgraph.toposort()] != [ assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
sigmoid, ps.sigmoid,
mul, ps.mul,
inplace.neg_inplace, ps.neg,
] ]
f(data) f(data)
......
...@@ -17,7 +17,6 @@ from pytensor.configdefaults import config ...@@ -17,7 +17,6 @@ from pytensor.configdefaults import config
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph.rewriting.basic import in2out from pytensor.graph.rewriting.basic import in2out
from pytensor.graph.utils import InconsistencyError from pytensor.graph.utils import InconsistencyError
from pytensor.tensor import inplace
from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.blas import ( from pytensor.tensor.blas import (
BatchedDot, BatchedDot,
...@@ -40,6 +39,7 @@ from pytensor.tensor.blas import ( ...@@ -40,6 +39,7 @@ from pytensor.tensor.blas import (
ger, ger,
ger_destructive, ger_destructive,
) )
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, dot, mean, mul, outer, sigmoid from pytensor.tensor.math import Dot, dot, mean, mul, outer, sigmoid
from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger
from pytensor.tensor.type import ( from pytensor.tensor.type import (
...@@ -258,16 +258,20 @@ class TestGemm: ...@@ -258,16 +258,20 @@ class TestGemm:
rng = np.random.default_rng(seed=utt.fetch_seed()) rng = np.random.default_rng(seed=utt.fetch_seed())
Z = as_tensor_variable(rng.random((2, 2))) Z = as_tensor_variable(rng.random((2, 2)))
A = as_tensor_variable(rng.random((2, 2))) A = as_tensor_variable(rng.random((2, 2)))
Zt = Z.transpose()
assert isinstance(Zt.owner.op, DimShuffle) and Zt.owner.op.view_map == {0: [0]}
with pytest.raises(InconsistencyError, match=Gemm.E_z_uniq): with pytest.raises(InconsistencyError, match=Gemm.E_z_uniq):
gemm_inplace(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0) gemm_inplace(Z, 1.0, A, Zt, 1.0)
def test_destroy_map2(self): def test_destroy_map2(self):
# test that only first input can be overwritten. # test that only first input can be overwritten.
rng = np.random.default_rng(seed=utt.fetch_seed()) rng = np.random.default_rng(seed=utt.fetch_seed())
Z = as_tensor_variable(rng.random((2, 2))) Z = as_tensor_variable(rng.random((2, 2)))
A = as_tensor_variable(rng.random((2, 2))) A = as_tensor_variable(rng.random((2, 2)))
Zt = Z.transpose()
assert isinstance(Zt.owner.op, DimShuffle) and Zt.owner.op.view_map == {0: [0]}
with pytest.raises(InconsistencyError, match=Gemm.E_z_uniq): with pytest.raises(InconsistencyError, match=Gemm.E_z_uniq):
gemm_inplace(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0) gemm_inplace(Z, 1.0, Zt, A, 1.0)
def test_destroy_map3(self): def test_destroy_map3(self):
# test that only first input can be overwritten # test that only first input can be overwritten
......
...@@ -20,6 +20,9 @@ from pytensor.graph.replace import vectorize_node ...@@ -20,6 +20,9 @@ from pytensor.graph.replace import vectorize_node
from pytensor.link.basic import PerformLinker from pytensor.link.basic import PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.scalar import ScalarOp, float32, float64, int32, int64 from pytensor.scalar import ScalarOp, float32, float64, int32, int64
from pytensor.scalar import add as scalar_add
from pytensor.scalar import exp as scalar_exp
from pytensor.scalar import xor as scalar_xor
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import get_scalar_constant_value, second from pytensor.tensor.basic import get_scalar_constant_value, second
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
...@@ -43,6 +46,16 @@ from pytensor.tensor.type import ( ...@@ -43,6 +46,16 @@ from pytensor.tensor.type import (
) )
from tests import unittest_tools from tests import unittest_tools
from tests.link.test_link import make_function from tests.link.test_link import make_function
from tests.tensor.utils import (
_bad_runtime_broadcast_binary_normal,
inplace_func,
integers,
integers_uint16,
integers_uint32,
makeBroadcastTester,
random,
random_complex,
)
def reduce_bitwise_and(x, axis=-1, dtype="int8"): def reduce_bitwise_and(x, axis=-1, dtype="int8"):
...@@ -334,7 +347,7 @@ class TestBroadcast: ...@@ -334,7 +347,7 @@ class TestBroadcast:
x = x_type("x") x = x_type("x")
y = y_type("y") y = y_type("y")
e = op(ps.Add(ps.transfer_type(0)), {0: 0})(x, y) e = op(ps.add, {0: 0})(x, y)
f = make_function(copy(linker).accept(FunctionGraph([x, y], [e]))) f = make_function(copy(linker).accept(FunctionGraph([x, y], [e])))
xv = rand_val(xsh) xv = rand_val(xsh)
yv = rand_val(ysh) yv = rand_val(ysh)
...@@ -348,7 +361,7 @@ class TestBroadcast: ...@@ -348,7 +361,7 @@ class TestBroadcast:
if isinstance(linker, PerformLinker): if isinstance(linker, PerformLinker):
x = x_type("x") x = x_type("x")
y = y_type("y") y = y_type("y")
e = op(ps.Add(ps.transfer_type(0)), {0: 0})(x, y) e = op(ps.add, {0: 0})(x, y)
f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape]))) f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape])))
xv = rand_val(xsh) xv = rand_val(xsh)
yv = rand_val(ysh) yv = rand_val(ysh)
...@@ -390,7 +403,10 @@ class TestBroadcast: ...@@ -390,7 +403,10 @@ class TestBroadcast:
): ):
x = t(pytensor.config.floatX, shape=(None, None))("x") x = t(pytensor.config.floatX, shape=(None, None))("x")
y = t(pytensor.config.floatX, shape=(1, 1))("y") y = t(pytensor.config.floatX, shape=(1, 1))("y")
e = op(ps.Second(ps.transfer_type(0)), {0: 0})(x, y) op1 = op(ps.second, {0: 0})
op2 = op(ps.second, {0: 0})
assert op1 == op2
e = op(ps.Second(), {0: 0})(x, y)
f = make_function(linker().accept(FunctionGraph([x, y], [e]))) f = make_function(linker().accept(FunctionGraph([x, y], [e])))
xv = rval((5, 5)) xv = rval((5, 5))
yv = rval((1, 1)) yv = rval((1, 1))
...@@ -1113,3 +1129,74 @@ def test_numpy_warning_suppressed(): ...@@ -1113,3 +1129,74 @@ def test_numpy_warning_suppressed():
y = pt.log(x) y = pt.log(x)
fn = pytensor.function([x], y, mode=Mode(linker="py")) fn = pytensor.function([x], y, mode=Mode(linker="py"))
assert fn(0) == -np.inf assert fn(0) == -np.inf
rng = np.random.default_rng(18)
_good_add_inplace = dict(
same_shapes=(random(2, 3, rng=rng), random(2, 3, rng=rng)),
not_same_dimensions=(random(2, 2, rng=rng), random(2, rng=rng)),
scalar=(random(2, 3, rng=rng), random(1, 1, rng=rng)),
row=(random(2, 3, rng=rng), random(1, 3, rng=rng)),
column=(random(2, 3, rng=rng), random(2, 1, rng=rng)),
integers=(integers(2, 3, rng=rng), integers(2, 3, rng=rng)),
uint32=(integers_uint32(2, 3, rng=rng), integers_uint32(2, 3, rng=rng)),
uint16=(integers_uint16(2, 3, rng=rng), integers_uint16(2, 3, rng=rng)),
# (float32, >int16) upcasts to float64 by default
dtype_valid_mixup=(
random(2, 3, rng=rng),
integers(2, 3, rng=rng).astype(
"int16" if config.floatX == "float32" else "int64"
),
),
complex1=(random_complex(2, 3, rng=rng), random_complex(2, 3, rng=rng)),
complex2=(random_complex(2, 3, rng=rng), random(2, 3, rng=rng)),
empty=(np.asarray([], dtype=config.floatX), np.asarray([1], dtype=config.floatX)),
)
TestAddInplaceBroadcast = makeBroadcastTester(
op=Elemwise(scalar_add, {0: 0}),
expected=lambda x, y: x + y,
good=_good_add_inplace,
# Cannot inplace on first input if it doesn't match output dtype (upcast of inputs)
bad_build=dict(dtype_invalid_mixup=_good_add_inplace["dtype_valid_mixup"][::-1]),
bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True,
)
@pytest.mark.xfail(
config.cycle_detection == "fast" and config.mode != "FAST_COMPILE",
reason="Cycle detection is fast and mode is FAST_COMPILE",
)
def test_exp_inplace_grad_1():
utt.verify_grad(
Elemwise(scalar_exp, {0: 0}),
[
np.asarray(
[
[1.5089518, 1.48439076, -4.7820262],
[2.04832468, 0.50791564, -1.58892269],
]
)
],
)
def test_XOR_inplace():
dtype = [
"int8",
"int16",
"int32",
"int64",
]
xor_inplace = Elemwise(scalar_xor, {0: 0})
for dtype in dtype:
x, y = vector(dtype=dtype), vector(dtype=dtype)
l = np.asarray([0, 0, 1, 1], dtype=dtype)
r = np.asarray([0, 1, 0, 1], dtype=dtype)
ix = x
ix = xor_inplace(ix, y)
gn = inplace_func([x, y], ix)
_ = gn(l, r)
# test the in-place stuff
assert np.all(l == np.asarray([0, 1, 1, 0])), l
import numpy as np
import pytest
from pytensor import config
from pytensor.scalar.basic import round_half_away_from_zero_vec, upcast
from pytensor.tensor.inplace import (
abs_inplace,
add_inplace,
arccos_inplace,
arccosh_inplace,
arcsin_inplace,
arcsinh_inplace,
arctan2_inplace,
arctan_inplace,
arctanh_inplace,
ceil_inplace,
conj_inplace,
cos_inplace,
cosh_inplace,
deg2rad_inplace,
exp2_inplace,
exp_inplace,
expm1_inplace,
floor_inplace,
int_div_inplace,
log1p_inplace,
log2_inplace,
log10_inplace,
log_inplace,
maximum_inplace,
minimum_inplace,
mod_inplace,
mul_inplace,
neg_inplace,
pow_inplace,
rad2deg_inplace,
reciprocal_inplace,
round_half_away_from_zero_inplace,
round_half_to_even_inplace,
sign_inplace,
sin_inplace,
sinh_inplace,
sqr_inplace,
sqrt_inplace,
sub_inplace,
tan_inplace,
tanh_inplace,
true_div_inplace,
trunc_inplace,
xor_inplace,
)
from pytensor.tensor.type import vector
from tests import unittest_tools as utt
from tests.tensor.utils import (
_bad_build_broadcast_binary_normal,
_bad_runtime_broadcast_binary_normal,
_bad_runtime_reciprocal,
_good_broadcast_binary_arctan2,
_good_broadcast_binary_normal,
_good_broadcast_div_mod_normal_float_inplace,
_good_broadcast_pow_normal_float_pow,
_good_broadcast_unary_arccosh,
_good_broadcast_unary_arcsin_float,
_good_broadcast_unary_arctanh,
_good_broadcast_unary_normal,
_good_broadcast_unary_normal_abs,
_good_broadcast_unary_normal_float,
_good_broadcast_unary_normal_float_no_complex,
_good_broadcast_unary_normal_float_no_empty_no_complex,
_good_broadcast_unary_normal_no_complex,
_good_broadcast_unary_positive_float,
_good_broadcast_unary_tan,
_good_broadcast_unary_wide_float,
_good_reciprocal_inplace,
_numpy_true_div,
angle_eps,
check_floatX,
copymod,
div_grad_rtol,
ignore_isfinite_mode,
inplace_func,
makeBroadcastTester,
upcast_float16_ufunc,
)
TestAddInplaceBroadcast = makeBroadcastTester(
op=add_inplace,
expected=lambda x, y: x + y,
good=_good_broadcast_binary_normal,
bad_build=_bad_build_broadcast_binary_normal,
bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True,
)
TestSubInplaceBroadcast = makeBroadcastTester(
op=sub_inplace,
expected=lambda x, y: x - y,
good=_good_broadcast_binary_normal,
bad_build=_bad_build_broadcast_binary_normal,
bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True,
)
TestMaximumInplaceBroadcast = makeBroadcastTester(
op=maximum_inplace,
expected=np.maximum,
good=_good_broadcast_binary_normal,
bad_build=_bad_build_broadcast_binary_normal,
bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True,
)
TestMinimumInplaceBroadcast = makeBroadcastTester(
op=minimum_inplace,
expected=np.minimum,
good=_good_broadcast_binary_normal,
bad_build=_bad_build_broadcast_binary_normal,
bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True,
)
TestMulInplaceBroadcast = makeBroadcastTester(
op=mul_inplace,
expected=lambda x, y: x * y,
good=_good_broadcast_binary_normal,
bad_build=_bad_build_broadcast_binary_normal,
bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True,
)
TestTrueDivInplaceBroadcast = makeBroadcastTester(
op=true_div_inplace,
expected=_numpy_true_div,
good=copymod(
_good_broadcast_div_mod_normal_float_inplace,
# The output is now in float, we cannot work inplace on an int.
without=["integer", "uint8", "uint16", "int8"],
),
grad_rtol=div_grad_rtol,
inplace=True,
)
TestReciprocalInplaceBroadcast = makeBroadcastTester(
op=reciprocal_inplace,
expected=lambda x: _numpy_true_div(np.int8(1), x),
good=_good_reciprocal_inplace,
bad_runtime=_bad_runtime_reciprocal,
grad_rtol=div_grad_rtol,
inplace=True,
)
TestModInplaceBroadcast = makeBroadcastTester(
op=mod_inplace,
expected=lambda x, y: np.asarray(x % y, dtype=upcast(x.dtype, y.dtype)),
good=copymod(
_good_broadcast_div_mod_normal_float_inplace, ["complex1", "complex2"]
),
grad_eps=1e-5,
inplace=True,
)
TestPowInplaceBroadcast = makeBroadcastTester(
op=pow_inplace,
expected=lambda x, y: x**y,
good=_good_broadcast_pow_normal_float_pow,
inplace=True,
mode=ignore_isfinite_mode,
)
TestNegInplaceBroadcast = makeBroadcastTester(
op=neg_inplace,
expected=lambda x: -x,
good=_good_broadcast_unary_normal,
inplace=True,
)
TestSgnInplaceBroadcast = makeBroadcastTester(
op=sign_inplace,
expected=np.sign,
good=_good_broadcast_unary_normal_no_complex,
inplace=True,
)
TestAbsInplaceBroadcast = makeBroadcastTester(
op=abs_inplace,
expected=lambda x: np.abs(x),
good=_good_broadcast_unary_normal_abs,
inplace=True,
)
TestIntDivInplaceBroadcast = makeBroadcastTester(
op=int_div_inplace,
expected=lambda x, y: check_floatX((x, y), x // y),
good=_good_broadcast_div_mod_normal_float_inplace,
# I don't test the grad as the output is always an integer
# (this is not a continuous output).
# grad=_grad_broadcast_div_mod_normal,
inplace=True,
)
TestCeilInplaceBroadcast = makeBroadcastTester(
op=ceil_inplace,
expected=upcast_float16_ufunc(np.ceil),
good=copymod(
_good_broadcast_unary_normal_no_complex,
without=["integers", "int8", "uint8", "uint16"],
),
# corner cases includes a lot of integers: points where Ceil is not
# continuous (not differentiable)
inplace=True,
)
TestFloorInplaceBroadcast = makeBroadcastTester(
op=floor_inplace,
expected=upcast_float16_ufunc(np.floor),
good=copymod(
_good_broadcast_unary_normal_no_complex,
without=["integers", "int8", "uint8", "uint16"],
),
inplace=True,
)
TestTruncInplaceBroadcast = makeBroadcastTester(
op=trunc_inplace,
expected=upcast_float16_ufunc(np.trunc),
good=_good_broadcast_unary_normal_no_complex,
inplace=True,
)
TestRoundHalfToEvenInplaceBroadcast = makeBroadcastTester(
op=round_half_to_even_inplace,
expected=np.round,
good=_good_broadcast_unary_normal_float_no_complex,
inplace=True,
)
TestRoundHalfAwayFromZeroInplaceBroadcast = makeBroadcastTester(
op=round_half_away_from_zero_inplace,
expected=lambda a: round_half_away_from_zero_vec(a),
good=_good_broadcast_unary_normal_float_no_empty_no_complex,
inplace=True,
)
TestSqrInplaceBroadcast = makeBroadcastTester(
op=sqr_inplace,
expected=np.square,
good=_good_broadcast_unary_normal,
inplace=True,
)
TestExpInplaceBroadcast = makeBroadcastTester(
op=exp_inplace,
expected=np.exp,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestExp2InplaceBroadcast = makeBroadcastTester(
op=exp2_inplace,
expected=np.exp2,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestExpm1InplaceBroadcast = makeBroadcastTester(
op=expm1_inplace,
expected=np.expm1,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestLogInplaceBroadcast = makeBroadcastTester(
op=log_inplace,
expected=np.log,
good=_good_broadcast_unary_positive_float,
inplace=True,
)
TestLog2InplaceBroadcast = makeBroadcastTester(
op=log2_inplace,
expected=np.log2,
good=_good_broadcast_unary_positive_float,
inplace=True,
)
TestLog10InplaceBroadcast = makeBroadcastTester(
op=log10_inplace,
expected=np.log10,
good=_good_broadcast_unary_positive_float,
inplace=True,
)
TestLog1pInplaceBroadcast = makeBroadcastTester(
op=log1p_inplace,
expected=np.log1p,
good=_good_broadcast_unary_positive_float,
inplace=True,
)
TestSqrtInplaceBroadcast = makeBroadcastTester(
op=sqrt_inplace,
expected=np.sqrt,
good=_good_broadcast_unary_positive_float,
inplace=True,
)
TestDeg2radInplaceBroadcast = makeBroadcastTester(
op=deg2rad_inplace,
expected=np.deg2rad,
good=_good_broadcast_unary_normal_float_no_complex,
inplace=True,
eps=angle_eps,
)
TestRad2degInplaceBroadcast = makeBroadcastTester(
op=rad2deg_inplace,
expected=np.rad2deg,
good=_good_broadcast_unary_normal_float_no_complex,
inplace=True,
eps=angle_eps,
)
TestSinInplaceBroadcast = makeBroadcastTester(
op=sin_inplace,
expected=np.sin,
good=_good_broadcast_unary_wide_float,
inplace=True,
)
TestArcsinInplaceBroadcast = makeBroadcastTester(
op=arcsin_inplace,
expected=np.arcsin,
good=_good_broadcast_unary_arcsin_float,
inplace=True,
)
TestCosInplaceBroadcast = makeBroadcastTester(
op=cos_inplace,
expected=np.cos,
good=_good_broadcast_unary_wide_float,
inplace=True,
)
TestArccosInplaceBroadcast = makeBroadcastTester(
op=arccos_inplace,
expected=np.arccos,
good=_good_broadcast_unary_arcsin_float,
inplace=True,
)
TestTanInplaceBroadcast = makeBroadcastTester(
op=tan_inplace,
expected=np.tan,
good=copymod(
_good_broadcast_unary_tan, without=["integers", "int8", "uint8", "uint16"]
),
inplace=True,
)
TestArctanInplaceBroadcast = makeBroadcastTester(
op=arctan_inplace,
expected=np.arctan,
good=_good_broadcast_unary_wide_float,
inplace=True,
)
TestArctan2InplaceBroadcast = makeBroadcastTester(
op=arctan2_inplace,
expected=np.arctan2,
good=copymod(
_good_broadcast_binary_arctan2,
without=["integers", "int8", "uint8", "uint16", "dtype_mixup_2"],
),
inplace=True,
)
TestCoshInplaceBroadcast = makeBroadcastTester(
op=cosh_inplace,
expected=np.cosh,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestArccoshInplaceBroadcast = makeBroadcastTester(
op=arccosh_inplace,
expected=np.arccosh,
good=copymod(_good_broadcast_unary_arccosh, without=["integers", "uint8"]),
inplace=True,
)
TestSinhInplaceBroadcast = makeBroadcastTester(
op=sinh_inplace,
expected=np.sinh,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestArcsinhInplaceBroadcast = makeBroadcastTester(
op=arcsinh_inplace,
expected=np.arcsinh,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestTanhInplaceBroadcast = makeBroadcastTester(
op=tanh_inplace,
expected=np.tanh,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestArctanhInplaceBroadcast = makeBroadcastTester(
op=arctanh_inplace,
expected=np.arctanh,
good=copymod(
_good_broadcast_unary_arctanh, without=["integers", "int8", "uint8", "uint16"]
),
inplace=True,
)
TestConjInplaceBroadcast = makeBroadcastTester(
op=conj_inplace,
expected=np.conj,
good=_good_broadcast_unary_normal,
inplace=True,
)
@pytest.mark.xfail(
config.cycle_detection == "fast" and config.mode != "FAST_COMPILE",
reason="Cycle detection is fast and mode is FAST_COMPILE",
)
def test_exp_inplace_grad_1():
utt.verify_grad(
exp_inplace,
[
np.asarray(
[
[1.5089518, 1.48439076, -4.7820262],
[2.04832468, 0.50791564, -1.58892269],
]
)
],
)
def test_XOR_inplace():
dtype = [
"int8",
"int16",
"int32",
"int64",
]
for dtype in dtype:
x, y = vector(dtype=dtype), vector(dtype=dtype)
l = np.asarray([0, 0, 1, 1], dtype=dtype)
r = np.asarray([0, 1, 0, 1], dtype=dtype)
ix = x
ix = xor_inplace(ix, y)
gn = inplace_func([x, y], ix)
_ = gn(l, r)
# test the in-place stuff
assert np.all(l == np.asarray([0, 1, 1, 0])), l
...@@ -12,13 +12,12 @@ from pytensor.compile.mode import get_default_mode ...@@ -12,13 +12,12 @@ from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import NullTypeGradError, verify_grad from pytensor.gradient import NullTypeGradError, verify_grad
from pytensor.scalar import ScalarLoop from pytensor.scalar import ScalarLoop
from pytensor.tensor import gammaincc, inplace, kn, kv, kve, vector from pytensor.tensor import gammaincc, kn, kv, kve, vector
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.tensor.utils import ( from tests.tensor.utils import (
_good_broadcast_unary_chi2sf, _good_broadcast_unary_chi2sf,
_good_broadcast_unary_normal, _good_broadcast_unary_normal,
_good_broadcast_unary_normal_float,
_good_broadcast_unary_normal_float_no_complex, _good_broadcast_unary_normal_float_no_complex,
_good_broadcast_unary_normal_float_no_complex_small_neg_range, _good_broadcast_unary_normal_float_no_complex_small_neg_range,
_good_broadcast_unary_normal_no_complex, _good_broadcast_unary_normal_no_complex,
...@@ -85,14 +84,6 @@ TestErfBroadcast = makeBroadcastTester( ...@@ -85,14 +84,6 @@ TestErfBroadcast = makeBroadcastTester(
eps=2e-10, eps=2e-10,
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestErfInplaceBroadcast = makeBroadcastTester(
op=inplace.erf_inplace,
expected=expected_erf,
good=_good_broadcast_unary_normal_float,
mode=mode_no_scipy,
eps=2e-10,
inplace=True,
)
TestErfcBroadcast = makeBroadcastTester( TestErfcBroadcast = makeBroadcastTester(
op=pt.erfc, op=pt.erfc,
...@@ -102,14 +93,6 @@ TestErfcBroadcast = makeBroadcastTester( ...@@ -102,14 +93,6 @@ TestErfcBroadcast = makeBroadcastTester(
eps=2e-10, eps=2e-10,
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestErfcInplaceBroadcast = makeBroadcastTester(
op=inplace.erfc_inplace,
expected=expected_erfc,
good=_good_broadcast_unary_normal_float_no_complex,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestErfcxBroadcast = makeBroadcastTester( TestErfcxBroadcast = makeBroadcastTester(
op=pt.erfcx, op=pt.erfcx,
...@@ -119,14 +102,6 @@ TestErfcxBroadcast = makeBroadcastTester( ...@@ -119,14 +102,6 @@ TestErfcxBroadcast = makeBroadcastTester(
eps=2e-10, eps=2e-10,
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestErfcxInplaceBroadcast = makeBroadcastTester(
op=inplace.erfcx_inplace,
expected=expected_erfcx,
good=_good_broadcast_unary_normal_float_no_complex_small_neg_range,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestErfinvBroadcast = makeBroadcastTester( TestErfinvBroadcast = makeBroadcastTester(
op=pt.erfinv, op=pt.erfinv,
...@@ -192,14 +167,6 @@ TestOwensTBroadcast = makeBroadcastTester( ...@@ -192,14 +167,6 @@ TestOwensTBroadcast = makeBroadcastTester(
eps=2e-10, eps=2e-10,
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestOwensTInplaceBroadcast = makeBroadcastTester(
op=inplace.owens_t_inplace,
expected=expected_owenst,
good=_good_broadcast_binary_owenst,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
rng = np.random.default_rng(seed=utt.fetch_seed()) rng = np.random.default_rng(seed=utt.fetch_seed())
_good_broadcast_unary_gammaln = dict( _good_broadcast_unary_gammaln = dict(
...@@ -223,14 +190,6 @@ TestGammaBroadcast = makeBroadcastTester( ...@@ -223,14 +190,6 @@ TestGammaBroadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
eps=1e-5, eps=1e-5,
) )
TestGammaInplaceBroadcast = makeBroadcastTester(
op=inplace.gamma_inplace,
expected=expected_gamma,
good=_good_broadcast_unary_gammaln,
mode=mode_no_scipy,
eps=1e-5,
inplace=True,
)
TestGammalnBroadcast = makeBroadcastTester( TestGammalnBroadcast = makeBroadcastTester(
op=pt.gammaln, op=pt.gammaln,
...@@ -240,14 +199,6 @@ TestGammalnBroadcast = makeBroadcastTester( ...@@ -240,14 +199,6 @@ TestGammalnBroadcast = makeBroadcastTester(
eps=2e-10, eps=2e-10,
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestGammalnInplaceBroadcast = makeBroadcastTester(
op=inplace.gammaln_inplace,
expected=expected_gammaln,
good=_good_broadcast_unary_gammaln,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
rng = np.random.default_rng(seed=utt.fetch_seed()) rng = np.random.default_rng(seed=utt.fetch_seed())
_good_broadcast_unary_psi = dict( _good_broadcast_unary_psi = dict(
...@@ -265,14 +216,6 @@ TestPsiBroadcast = makeBroadcastTester( ...@@ -265,14 +216,6 @@ TestPsiBroadcast = makeBroadcastTester(
eps=2e-10, eps=2e-10,
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestPsiInplaceBroadcast = makeBroadcastTester(
op=inplace.psi_inplace,
expected=expected_psi,
good=_good_broadcast_unary_psi,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
_good_broadcast_unary_tri_gamma = _good_broadcast_unary_psi _good_broadcast_unary_tri_gamma = _good_broadcast_unary_psi
...@@ -283,14 +226,6 @@ TestTriGammaBroadcast = makeBroadcastTester( ...@@ -283,14 +226,6 @@ TestTriGammaBroadcast = makeBroadcastTester(
eps=2e-8, eps=2e-8,
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestTriGammaInplaceBroadcast = makeBroadcastTester(
op=inplace.tri_gamma_inplace,
expected=expected_tri_gamma,
good=_good_broadcast_unary_tri_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
TestChi2SFBroadcast = makeBroadcastTester( TestChi2SFBroadcast = makeBroadcastTester(
op=pt.chi2sf, op=pt.chi2sf,
...@@ -343,15 +278,6 @@ TestGammaIncBroadcast = makeBroadcastTester( ...@@ -343,15 +278,6 @@ TestGammaIncBroadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestGammaIncInplaceBroadcast = makeBroadcastTester(
op=inplace.gammainc_inplace,
expected=expected_gammainc,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
TestGammaInccBroadcast = makeBroadcastTester( TestGammaInccBroadcast = makeBroadcastTester(
op=pt.gammaincc, op=pt.gammaincc,
expected=expected_gammaincc, expected=expected_gammaincc,
...@@ -361,15 +287,6 @@ TestGammaInccBroadcast = makeBroadcastTester( ...@@ -361,15 +287,6 @@ TestGammaInccBroadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestGammaInccInplaceBroadcast = makeBroadcastTester(
op=inplace.gammaincc_inplace,
expected=expected_gammaincc,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
def test_gammainc_ddk_tabulated_values(): def test_gammainc_ddk_tabulated_values():
# This test replicates part of the old STAN test: # This test replicates part of the old STAN test:
...@@ -447,15 +364,6 @@ TestGammaUBroadcast = makeBroadcastTester( ...@@ -447,15 +364,6 @@ TestGammaUBroadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestGammaUInplaceBroadcast = makeBroadcastTester(
op=inplace.gammau_inplace,
expected=expected_gammau,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
TestGammaLBroadcast = makeBroadcastTester( TestGammaLBroadcast = makeBroadcastTester(
op=pt.gammal, op=pt.gammal,
expected=expected_gammal, expected=expected_gammal,
...@@ -464,15 +372,6 @@ TestGammaLBroadcast = makeBroadcastTester( ...@@ -464,15 +372,6 @@ TestGammaLBroadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestGammaLInplaceBroadcast = makeBroadcastTester(
op=inplace.gammal_inplace,
expected=expected_gammal,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
rng = np.random.default_rng(seed=utt.fetch_seed()) rng = np.random.default_rng(seed=utt.fetch_seed())
_good_broadcast_binary_gamma = dict( _good_broadcast_binary_gamma = dict(
normal=( normal=(
...@@ -490,15 +389,6 @@ TestGammaIncInvBroadcast = makeBroadcastTester( ...@@ -490,15 +389,6 @@ TestGammaIncInvBroadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestGammaIncInvInplaceBroadcast = makeBroadcastTester(
op=inplace.gammaincinv_inplace,
expected=expected_gammaincinv,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
TestGammaInccInvBroadcast = makeBroadcastTester( TestGammaInccInvBroadcast = makeBroadcastTester(
op=pt.gammainccinv, op=pt.gammainccinv,
expected=expected_gammainccinv, expected=expected_gammainccinv,
...@@ -507,15 +397,6 @@ TestGammaInccInvBroadcast = makeBroadcastTester( ...@@ -507,15 +397,6 @@ TestGammaInccInvBroadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestGammaInccInvInplaceBroadcast = makeBroadcastTester(
op=inplace.gammainccinv_inplace,
expected=expected_gammainccinv,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
rng = np.random.default_rng(seed=utt.fetch_seed()) rng = np.random.default_rng(seed=utt.fetch_seed())
_good_broadcast_unary_bessel = dict( _good_broadcast_unary_bessel = dict(
normal=(random_ranged(-10, 10, (2, 3), rng=rng),), normal=(random_ranged(-10, 10, (2, 3), rng=rng),),
...@@ -562,15 +443,6 @@ TestJ0Broadcast = makeBroadcastTester( ...@@ -562,15 +443,6 @@ TestJ0Broadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestJ0InplaceBroadcast = makeBroadcastTester(
op=inplace.j0_inplace,
expected=expected_j0,
good=_good_broadcast_unary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestJ1Broadcast = makeBroadcastTester( TestJ1Broadcast = makeBroadcastTester(
op=pt.j1, op=pt.j1,
expected=expected_j1, expected=expected_j1,
...@@ -580,15 +452,6 @@ TestJ1Broadcast = makeBroadcastTester( ...@@ -580,15 +452,6 @@ TestJ1Broadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestJ1InplaceBroadcast = makeBroadcastTester(
op=inplace.j1_inplace,
expected=expected_j1,
good=_good_broadcast_unary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestJvBroadcast = makeBroadcastTester( TestJvBroadcast = makeBroadcastTester(
op=pt.jv, op=pt.jv,
expected=expected_jv, expected=expected_jv,
...@@ -597,15 +460,6 @@ TestJvBroadcast = makeBroadcastTester( ...@@ -597,15 +460,6 @@ TestJvBroadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestJvInplaceBroadcast = makeBroadcastTester(
op=inplace.jv_inplace,
expected=expected_jv,
good=_good_broadcast_binary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
def test_verify_jv_grad(): def test_verify_jv_grad():
# Verify Jv gradient. # Verify Jv gradient.
...@@ -628,15 +482,6 @@ TestI0Broadcast = makeBroadcastTester( ...@@ -628,15 +482,6 @@ TestI0Broadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestI0InplaceBroadcast = makeBroadcastTester(
op=inplace.i0_inplace,
expected=expected_i0,
good=_good_broadcast_unary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestI1Broadcast = makeBroadcastTester( TestI1Broadcast = makeBroadcastTester(
op=pt.i1, op=pt.i1,
expected=expected_i1, expected=expected_i1,
...@@ -646,15 +491,6 @@ TestI1Broadcast = makeBroadcastTester( ...@@ -646,15 +491,6 @@ TestI1Broadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestI1InplaceBroadcast = makeBroadcastTester(
op=inplace.i1_inplace,
expected=expected_i1,
good=_good_broadcast_unary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestIvBroadcast = makeBroadcastTester( TestIvBroadcast = makeBroadcastTester(
op=pt.iv, op=pt.iv,
expected=expected_iv, expected=expected_iv,
...@@ -663,15 +499,6 @@ TestIvBroadcast = makeBroadcastTester( ...@@ -663,15 +499,6 @@ TestIvBroadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestIvInplaceBroadcast = makeBroadcastTester(
op=inplace.iv_inplace,
expected=expected_iv,
good=_good_broadcast_binary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestIveBroadcast = makeBroadcastTester( TestIveBroadcast = makeBroadcastTester(
op=pt.ive, op=pt.ive,
expected=expected_ive, expected=expected_ive,
...@@ -680,15 +507,6 @@ TestIveBroadcast = makeBroadcastTester( ...@@ -680,15 +507,6 @@ TestIveBroadcast = makeBroadcastTester(
mode=mode_no_scipy, mode=mode_no_scipy,
) )
TestIveInplaceBroadcast = makeBroadcastTester(
op=inplace.ive_inplace,
expected=expected_ive,
good=_good_broadcast_binary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
def test_verify_iv_grad(): def test_verify_iv_grad():
# Verify Iv gradient. # Verify Iv gradient.
...@@ -721,15 +539,6 @@ TestSigmoidBroadcast = makeBroadcastTester( ...@@ -721,15 +539,6 @@ TestSigmoidBroadcast = makeBroadcastTester(
eps=1e-8, eps=1e-8,
) )
TestSigmoidInplaceBroadcast = makeBroadcastTester(
op=inplace.sigmoid_inplace,
expected=expected_sigmoid,
good=_good_broadcast_unary_normal_no_complex,
grad=_grad_broadcast_unary_normal,
eps=1e-8,
inplace=True,
)
class TestSigmoid: class TestSigmoid:
def test_elemwise(self): def test_elemwise(self):
...@@ -758,15 +567,6 @@ TestSoftplusBroadcast = makeBroadcastTester( ...@@ -758,15 +567,6 @@ TestSoftplusBroadcast = makeBroadcastTester(
eps=1e-8, eps=1e-8,
) )
TestSoftplusInplaceBroadcast = makeBroadcastTester(
op=inplace.softplus_inplace,
expected=expected_sofplus,
good=_good_broadcast_unary_softplus,
grad=_grad_broadcast_unary_normal,
eps=1e-8,
inplace=True,
)
class TestSoftplus: class TestSoftplus:
def test_elemwise(self): def test_elemwise(self):
...@@ -805,14 +605,6 @@ TestLog1mexpBroadcast = makeBroadcastTester( ...@@ -805,14 +605,6 @@ TestLog1mexpBroadcast = makeBroadcastTester(
eps=1e-8, eps=1e-8,
) )
TestLog1mexpInplaceBroadcast = makeBroadcastTester(
op=inplace.log1mexp_inplace,
expected=expected_log1mexp,
good=_good_broadcast_unary_log1mexp,
eps=1e-8,
inplace=True,
)
_good_broadcast_ternary_betainc = dict( _good_broadcast_ternary_betainc = dict(
normal=( normal=(
random_ranged(0, 1000, (2, 3)), random_ranged(0, 1000, (2, 3)),
...@@ -828,14 +620,6 @@ TestBetaincBroadcast = makeBroadcastTester( ...@@ -828,14 +620,6 @@ TestBetaincBroadcast = makeBroadcastTester(
grad=_good_broadcast_ternary_betainc, grad=_good_broadcast_ternary_betainc,
) )
TestBetaincInplaceBroadcast = makeBroadcastTester(
op=inplace.betainc_inplace,
expected=special.betainc,
good=_good_broadcast_ternary_betainc,
grad=_good_broadcast_ternary_betainc,
inplace=True,
)
class TestBetaIncGrad: class TestBetaIncGrad:
def test_stan_grad_partial(self): def test_stan_grad_partial(self):
...@@ -926,13 +710,6 @@ TestBetaincinvBroadcast = makeBroadcastTester( ...@@ -926,13 +710,6 @@ TestBetaincinvBroadcast = makeBroadcastTester(
good=_good_broadcast_ternary_betaincinv, good=_good_broadcast_ternary_betaincinv,
) )
TestBetaincinvInplaceBroadcast = makeBroadcastTester(
op=inplace.betaincinv_inplace,
expected=special.betaincinv,
good=_good_broadcast_ternary_betaincinv,
inplace=True,
)
_good_broadcast_quaternary_hyp2f1 = dict( _good_broadcast_quaternary_hyp2f1 = dict(
normal=( normal=(
random_ranged(0, 20, (2, 3)), random_ranged(0, 20, (2, 3)),
...@@ -949,13 +726,6 @@ TestHyp2F1Broadcast = makeBroadcastTester( ...@@ -949,13 +726,6 @@ TestHyp2F1Broadcast = makeBroadcastTester(
grad=_good_broadcast_quaternary_hyp2f1, grad=_good_broadcast_quaternary_hyp2f1,
) )
TestHyp2F1InplaceBroadcast = makeBroadcastTester(
op=inplace.hyp2f1_inplace,
expected=expected_hyp2f1,
good=_good_broadcast_quaternary_hyp2f1,
inplace=True,
)
class TestHyp2F1Grad: class TestHyp2F1Grad:
few_iters_case = ( few_iters_case = (
......
...@@ -672,7 +672,9 @@ def makeTester( ...@@ -672,7 +672,9 @@ def makeTester(
return Checker return Checker
def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs): def makeBroadcastTester(
op, expected, checks=None, name=None, *, inplace=False, **kwargs
):
if checks is None: if checks is None:
checks = {} checks = {}
if name is None: if name is None:
...@@ -695,22 +697,20 @@ def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs): ...@@ -695,22 +697,20 @@ def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs):
# cases we need to add it manually. # cases we need to add it manually.
if not name.endswith("Tester"): if not name.endswith("Tester"):
name += "Tester" name += "Tester"
if "inplace" in kwargs: if inplace:
if kwargs["inplace"]: _expected = expected
_expected = expected if not isinstance(_expected, dict):
if not isinstance(_expected, dict):
def expected(*inputs):
def expected(*inputs): return np.array(_expected(*inputs), dtype=inputs[0].dtype)
return np.array(_expected(*inputs), dtype=inputs[0].dtype)
def inplace_check(inputs, outputs):
def inplace_check(inputs, outputs): # this used to be inputs[0] is output[0]
# this used to be inputs[0] is output[0] # I changed it so that it was easier to satisfy by the
# I changed it so that it was easier to satisfy by the # DebugMode
# DebugMode return np.all(inputs[0] == outputs[0])
return np.all(inputs[0] == outputs[0])
checks = dict(checks, inplace_check=inplace_check)
checks = dict(checks, inplace_check=inplace_check)
del kwargs["inplace"]
return makeTester(name, op, expected, checks, **kwargs) return makeTester(name, op, expected, checks, **kwargs)
...@@ -815,6 +815,7 @@ _good_broadcast_unary_normal_no_complex = dict( ...@@ -815,6 +815,7 @@ _good_broadcast_unary_normal_no_complex = dict(
big_scalar=[np.arange(17.0, 29.0, 0.5, dtype=config.floatX)], big_scalar=[np.arange(17.0, 29.0, 0.5, dtype=config.floatX)],
) )
# FIXME: Why is this empty?
_bad_build_broadcast_binary_normal = dict() _bad_build_broadcast_binary_normal = dict()
_bad_runtime_broadcast_binary_normal = dict( _bad_runtime_broadcast_binary_normal = dict(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论