提交 e0a2a865 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove predefined inplace Elemwise Ops and redundant tests

上级 42e8490c
......@@ -84,13 +84,13 @@ jobs:
install-mlx: [0]
install-xarray: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor"
- "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor"
- "tests/scan"
- "tests/tensor --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/conv --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py"
- "tests/tensor/rewriting"
- "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting"
- "tests/tensor/test_basic.py tests/tensor/test_elemwise.py"
- "tests/tensor/test_math.py"
- "tests/tensor/test_basic.py tests/tensor/test_inplace.py tests/tensor/conv"
- "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
- "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/conv"
- "tests/tensor/rewriting"
exclude:
- python-version: "3.11"
fast-compile: 1
......@@ -167,7 +167,7 @@ jobs:
install-numba: 0
install-jax: 0
install-torch: 0
part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
part: "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"
steps:
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
......
......@@ -20,7 +20,7 @@ from pytensor.misc.frozendict import frozendict
from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import identity as scalar_identity
from pytensor.scalar.basic import int64, transfer_type, upcast
from pytensor.scalar.basic import int64, upcast
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
......@@ -1634,17 +1634,12 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
symbolname = symbolname or symbol.__name__
if symbolname.endswith("_inplace"):
base_symbol_name = symbolname[: -len("_inplace")]
scalar_op = getattr(scalar, base_symbol_name)
inplace_scalar_op = scalar_op.__class__(transfer_type(0))
rval = Elemwise(
inplace_scalar_op,
{0: 0},
nfunc_spec=(nfunc and (nfunc, nin, nout)),
raise ValueError(
"Creation of automatic inplace elemwise operations deprecated"
)
else:
scalar_op = getattr(scalar, symbolname)
rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout)))
scalar_op = getattr(scalar, symbolname)
rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout)))
if getattr(symbol, "__doc__"):
rval.__doc__ = symbol.__doc__
......
from pytensor import printing
from pytensor.printing import pprint
from pytensor.tensor.elemwise import scalar_elemwise
@scalar_elemwise
def lt_inplace(a, b):
"""a < b (inplace on a)"""
@scalar_elemwise
def gt_inplace(a, b):
"""a > b (inplace on a)"""
@scalar_elemwise
def le_inplace(a, b):
"""a <= b (inplace on a)"""
@scalar_elemwise
def ge_inplace(a, b):
"""a >= b (inplace on a)"""
@scalar_elemwise
def eq_inplace(a, b):
"""a == b (inplace on a)"""
@scalar_elemwise
def neq_inplace(a, b):
"""a != b (inplace on a)"""
@scalar_elemwise
def and__inplace(a, b):
"""bitwise a & b (inplace on a)"""
@scalar_elemwise
def or__inplace(a, b):
"""bitwise a | b (inplace on a)"""
@scalar_elemwise
def xor_inplace(a, b):
"""bitwise a ^ b (inplace on a)"""
@scalar_elemwise
def invert_inplace(a):
"""bitwise ~a (inplace on a)"""
@scalar_elemwise
def abs_inplace(a):
"""|`a`| (inplace on `a`)"""
@scalar_elemwise
def exp_inplace(a):
"""e^`a` (inplace on `a`)"""
@scalar_elemwise
def exp2_inplace(a):
"""2^`a` (inplace on `a`)"""
@scalar_elemwise
def expm1_inplace(a):
"""e^`a` - 1 (inplace on `a`)"""
@scalar_elemwise
def neg_inplace(a):
"""-a (inplace on a)"""
@scalar_elemwise
def reciprocal_inplace(a):
"""1.0/a (inplace on a)"""
@scalar_elemwise
def log_inplace(a):
"""base e logarithm of a (inplace on a)"""
@scalar_elemwise
def log1p_inplace(a):
"""log(1+a)"""
@scalar_elemwise
def log2_inplace(a):
"""base 2 logarithm of a (inplace on a)"""
@scalar_elemwise
def log10_inplace(a):
"""base 10 logarithm of a (inplace on a)"""
@scalar_elemwise
def sign_inplace(a):
"""sign of `a` (inplace on `a`)"""
@scalar_elemwise
def ceil_inplace(a):
"""ceil of `a` (inplace on `a`)"""
@scalar_elemwise
def floor_inplace(a):
"""floor of `a` (inplace on `a`)"""
@scalar_elemwise
def trunc_inplace(a):
"""trunc of `a` (inplace on `a`)"""
@scalar_elemwise
def round_half_to_even_inplace(a):
"""round_half_to_even_inplace(a) (inplace on `a`)"""
@scalar_elemwise
def round_half_away_from_zero_inplace(a):
"""round_half_away_from_zero_inplace(a) (inplace on `a`)"""
@scalar_elemwise
def sqr_inplace(a):
"""square of `a` (inplace on `a`)"""
@scalar_elemwise
def sqrt_inplace(a):
"""square root of `a` (inplace on `a`)"""
@scalar_elemwise
def deg2rad_inplace(a):
"""convert degree `a` to radian(inplace on `a`)"""
@scalar_elemwise
def rad2deg_inplace(a):
"""convert radian `a` to degree(inplace on `a`)"""
@scalar_elemwise
def cos_inplace(a):
"""cosine of `a` (inplace on `a`)"""
@scalar_elemwise
def arccos_inplace(a):
"""arccosine of `a` (inplace on `a`)"""
@scalar_elemwise
def sin_inplace(a):
"""sine of `a` (inplace on `a`)"""
@scalar_elemwise
def arcsin_inplace(a):
"""arcsine of `a` (inplace on `a`)"""
@scalar_elemwise
def tan_inplace(a):
"""tangent of `a` (inplace on `a`)"""
@scalar_elemwise
def arctan_inplace(a):
"""arctangent of `a` (inplace on `a`)"""
@scalar_elemwise
def arctan2_inplace(a, b):
"""arctangent of `a` / `b` (inplace on `a`)"""
@scalar_elemwise
def cosh_inplace(a):
"""hyperbolic cosine of `a` (inplace on `a`)"""
@scalar_elemwise
def arccosh_inplace(a):
"""hyperbolic arc cosine of `a` (inplace on `a`)"""
@scalar_elemwise
def sinh_inplace(a):
"""hyperbolic sine of `a` (inplace on `a`)"""
@scalar_elemwise
def arcsinh_inplace(a):
"""hyperbolic arc sine of `a` (inplace on `a`)"""
@scalar_elemwise
def tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)"""
@scalar_elemwise
def arctanh_inplace(a):
"""hyperbolic arc tangent of `a` (inplace on `a`)"""
@scalar_elemwise
def erf_inplace(a):
"""error function"""
@scalar_elemwise
def erfc_inplace(a):
"""complementary error function"""
@scalar_elemwise
def erfcx_inplace(a):
"""scaled complementary error function"""
@scalar_elemwise
def owens_t_inplace(h, a):
"""owens t function"""
@scalar_elemwise
def gamma_inplace(a):
"""gamma function"""
@scalar_elemwise
def gammaln_inplace(a):
"""log gamma function"""
@scalar_elemwise
def psi_inplace(a):
"""derivative of log gamma function"""
@scalar_elemwise
def tri_gamma_inplace(a):
"""second derivative of the log gamma function"""
@scalar_elemwise
def gammainc_inplace(k, x):
"""regularized lower gamma function (P)"""
@scalar_elemwise
def gammaincc_inplace(k, x):
"""regularized upper gamma function (Q)"""
@scalar_elemwise
def gammau_inplace(k, x):
"""upper incomplete gamma function"""
@scalar_elemwise
def gammal_inplace(k, x):
"""lower incomplete gamma function"""
@scalar_elemwise
def gammaincinv_inplace(k, x):
"""Inverse to the regularized lower incomplete gamma function"""
@scalar_elemwise
def gammainccinv_inplace(k, x):
"""Inverse of the regularized upper incomplete gamma function"""
@scalar_elemwise
def j0_inplace(x):
"""Bessel function of the first kind of order 0."""
@scalar_elemwise
def j1_inplace(x):
"""Bessel function of the first kind of order 1."""
@scalar_elemwise
def jv_inplace(v, x):
"""Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def i0_inplace(x):
"""Modified Bessel function of the first kind of order 0."""
@scalar_elemwise
def i1_inplace(x):
"""Modified Bessel function of the first kind of order 1."""
@scalar_elemwise
def iv_inplace(v, x):
"""Modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def ive_inplace(v, x):
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""
@scalar_elemwise
def sigmoid_inplace(x):
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
@scalar_elemwise
def softplus_inplace(x):
"""Compute log(1 + exp(x)), also known as softplus or log1pexp"""
@scalar_elemwise
def log1mexp_inplace(x):
"""Compute log(1 - exp(x)), also known as log1mexp"""
@scalar_elemwise
def betainc_inplace(a, b, x):
"""Regularized incomplete beta function"""
@scalar_elemwise
def betaincinv_inplace(a, b, x):
"""Inverse of the regularized incomplete beta function"""
@scalar_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
fill_inplace = second_inplace
pprint.assign(fill_inplace, printing.FunctionPrinter(["fill="]))
@scalar_elemwise
def maximum_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@scalar_elemwise
def minimum_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@scalar_elemwise
def add_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@scalar_elemwise
def sub_inplace(a, b):
"""elementwise subtraction (inplace on `a`)"""
@scalar_elemwise
def mul_inplace(a, b):
"""elementwise multiplication (inplace on `a`)"""
@scalar_elemwise
def true_div_inplace(a, b):
"""elementwise division (inplace on `a`)"""
@scalar_elemwise
def int_div_inplace(a, b):
"""elementwise division (inplace on `a`)"""
@scalar_elemwise
def mod_inplace(a, b):
"""elementwise modulo (inplace on `a`)"""
@scalar_elemwise
def pow_inplace(a, b):
"""elementwise power (inplace on `a`)"""
@scalar_elemwise
def conj_inplace(a):
"""elementwise conjugate (inplace on `a`)"""
@scalar_elemwise
def hyp2f1_inplace(a, b, c, z):
"""gaussian hypergeometric function"""
pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))
pprint.assign(neg_inplace, printing.OperatorPrinter("-=", 0, "either"))
pprint.assign(true_div_inplace, printing.OperatorPrinter("/=", -1, "left"))
pprint.assign(int_div_inplace, printing.OperatorPrinter("//=", -1, "left"))
pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
def transpose_inplace(x, **kwargs):
"Perform a transpose on a tensor without copying the underlying storage"
dims = list(range(x.ndim - 1, -1, -1))
return x.dimshuffle(dims)
......@@ -6,13 +6,13 @@ import scipy.special
import pytensor
import pytensor.tensor as pt
import pytensor.tensor.inplace as pti
import pytensor.tensor.math as ptm
from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.compile.ops import deep_copy_op
from pytensor.gradient import grad
from pytensor.scalar import Composite, float64
from pytensor.scalar import add as scalar_add
from pytensor.tensor import blas, tensor
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
......@@ -30,6 +30,8 @@ from tests.tensor.test_elemwise import (
rng = np.random.default_rng(42849)
add_inplace = Elemwise(scalar_add, {0: 0})
@pytest.mark.parametrize(
"inputs, input_vals, output_fn",
......@@ -80,7 +82,7 @@ rng = np.random.default_rng(42849)
np.array(1.0, dtype=config.floatX),
np.array(1.0, dtype=config.floatX),
],
lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)),
lambda x, y: add_inplace(deep_copy_op(x), deep_copy_op(y)),
),
(
[pt.vector(), pt.vector()],
......@@ -88,7 +90,7 @@ rng = np.random.default_rng(42849)
rng.standard_normal(100).astype(config.floatX),
rng.standard_normal(100).astype(config.floatX),
],
lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)),
lambda x, y: add_inplace(deep_copy_op(x), deep_copy_op(y)),
),
(
[pt.vector(), pt.vector()],
......
......@@ -31,7 +31,6 @@ from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.graph.traversal import ancestors
from pytensor.printing import debugprint
from pytensor.scalar import PolyGamma, Psi, TriGamma
from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
......@@ -1134,15 +1133,15 @@ def test_log1p():
f = function([x], log(1 + (x)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] == [log1p]
f = function([x], log(1 + (-x)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] == [
neg,
inplace.log1p_inplace,
assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] == [
ps.neg,
ps.log1p,
]
f = function([x], -log(1 + (-x)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] == [
neg,
inplace.log1p_inplace,
inplace.neg_inplace,
assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] == [
ps.neg,
ps.log1p,
ps.neg,
]
# check trickier cases (and use different dtype)
......@@ -4035,27 +4034,27 @@ class TestSigmoidRewrites:
# todo: solve issue #4589 first
# assert check_stack_trace(
# f, ops_to_check=[sigmoid, neg_inplace])
assert [node.op for node in f.maker.fgraph.toposort()] == [
sigmoid,
inplace.neg_inplace,
assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] == [
ps.sigmoid,
ps.neg,
]
f(data)
f = pytensor.function([x], pt.fill(x, -1.0) / (1 - exp(-x)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] != [
sigmoid,
inplace.neg_inplace,
assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
ps.sigmoid,
ps.neg,
]
f(data)
f = pytensor.function([x], pt.fill(x, -1.0) / (2 + exp(-x)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] != [
sigmoid,
inplace.neg_inplace,
assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
ps.sigmoid,
ps.neg,
]
f(data)
f = pytensor.function([x], pt.fill(x, -1.1) / (1 + exp(-x)), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] != [
sigmoid,
inplace.neg_inplace,
assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
ps.sigmoid,
ps.neg,
]
f(data)
......@@ -4077,10 +4076,10 @@ class TestSigmoidRewrites:
(pt.fill(x, -1.1) * exp(x)) / ((1 + exp(x)) * (1 + exp(-x))),
mode=m,
)
assert [node.op for node in f.maker.fgraph.toposort()] != [
sigmoid,
mul,
inplace.neg_inplace,
assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
ps.sigmoid,
ps.mul,
ps.neg,
]
f(data)
f = pytensor.function(
......@@ -4088,10 +4087,10 @@ class TestSigmoidRewrites:
(pt.fill(x, -1.0) * exp(x)) / ((2 + exp(x)) * (1 + exp(-x))),
mode=m,
)
assert [node.op for node in f.maker.fgraph.toposort()] != [
sigmoid,
mul,
inplace.neg_inplace,
assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
ps.sigmoid,
ps.mul,
ps.neg,
]
f(data)
f = pytensor.function(
......@@ -4099,10 +4098,10 @@ class TestSigmoidRewrites:
(pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (2 + exp(-x))),
mode=m,
)
assert [node.op for node in f.maker.fgraph.toposort()] != [
sigmoid,
mul,
inplace.neg_inplace,
assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
ps.sigmoid,
ps.mul,
ps.neg,
]
f(data)
f = pytensor.function(
......@@ -4110,10 +4109,10 @@ class TestSigmoidRewrites:
(pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (1 + exp(x))),
mode=m,
)
assert [node.op for node in f.maker.fgraph.toposort()] != [
sigmoid,
mul,
inplace.neg_inplace,
assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
ps.sigmoid,
ps.mul,
ps.neg,
]
f(data)
f = pytensor.function(
......@@ -4121,10 +4120,10 @@ class TestSigmoidRewrites:
(pt.fill(x, -1.0) * exp(x)) / ((1 + exp(x)) * (2 + exp(-x))),
mode=m,
)
assert [node.op for node in f.maker.fgraph.toposort()] != [
sigmoid,
mul,
inplace.neg_inplace,
assert [node.op.scalar_op for node in f.maker.fgraph.toposort()] != [
ps.sigmoid,
ps.mul,
ps.neg,
]
f(data)
......
......@@ -17,7 +17,6 @@ from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph.rewriting.basic import in2out
from pytensor.graph.utils import InconsistencyError
from pytensor.tensor import inplace
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.blas import (
BatchedDot,
......@@ -40,6 +39,7 @@ from pytensor.tensor.blas import (
ger,
ger_destructive,
)
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, dot, mean, mul, outer, sigmoid
from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger
from pytensor.tensor.type import (
......@@ -258,16 +258,20 @@ class TestGemm:
rng = np.random.default_rng(seed=utt.fetch_seed())
Z = as_tensor_variable(rng.random((2, 2)))
A = as_tensor_variable(rng.random((2, 2)))
Zt = Z.transpose()
assert isinstance(Zt.owner.op, DimShuffle) and Zt.owner.op.view_map == {0: [0]}
with pytest.raises(InconsistencyError, match=Gemm.E_z_uniq):
gemm_inplace(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
gemm_inplace(Z, 1.0, A, Zt, 1.0)
def test_destroy_map2(self):
# test that only first input can be overwritten.
rng = np.random.default_rng(seed=utt.fetch_seed())
Z = as_tensor_variable(rng.random((2, 2)))
A = as_tensor_variable(rng.random((2, 2)))
Zt = Z.transpose()
assert isinstance(Zt.owner.op, DimShuffle) and Zt.owner.op.view_map == {0: [0]}
with pytest.raises(InconsistencyError, match=Gemm.E_z_uniq):
gemm_inplace(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
gemm_inplace(Z, 1.0, Zt, A, 1.0)
def test_destroy_map3(self):
# test that only first input can be overwritten
......
......@@ -20,6 +20,9 @@ from pytensor.graph.replace import vectorize_node
from pytensor.link.basic import PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.scalar import ScalarOp, float32, float64, int32, int64
from pytensor.scalar import add as scalar_add
from pytensor.scalar import exp as scalar_exp
from pytensor.scalar import xor as scalar_xor
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import get_scalar_constant_value, second
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
......@@ -43,6 +46,16 @@ from pytensor.tensor.type import (
)
from tests import unittest_tools
from tests.link.test_link import make_function
from tests.tensor.utils import (
_bad_runtime_broadcast_binary_normal,
inplace_func,
integers,
integers_uint16,
integers_uint32,
makeBroadcastTester,
random,
random_complex,
)
def reduce_bitwise_and(x, axis=-1, dtype="int8"):
......@@ -334,7 +347,7 @@ class TestBroadcast:
x = x_type("x")
y = y_type("y")
e = op(ps.Add(ps.transfer_type(0)), {0: 0})(x, y)
e = op(ps.add, {0: 0})(x, y)
f = make_function(copy(linker).accept(FunctionGraph([x, y], [e])))
xv = rand_val(xsh)
yv = rand_val(ysh)
......@@ -348,7 +361,7 @@ class TestBroadcast:
if isinstance(linker, PerformLinker):
x = x_type("x")
y = y_type("y")
e = op(ps.Add(ps.transfer_type(0)), {0: 0})(x, y)
e = op(ps.add, {0: 0})(x, y)
f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape])))
xv = rand_val(xsh)
yv = rand_val(ysh)
......@@ -390,7 +403,10 @@ class TestBroadcast:
):
x = t(pytensor.config.floatX, shape=(None, None))("x")
y = t(pytensor.config.floatX, shape=(1, 1))("y")
e = op(ps.Second(ps.transfer_type(0)), {0: 0})(x, y)
op1 = op(ps.second, {0: 0})
op2 = op(ps.second, {0: 0})
assert op1 == op2
e = op(ps.Second(), {0: 0})(x, y)
f = make_function(linker().accept(FunctionGraph([x, y], [e])))
xv = rval((5, 5))
yv = rval((1, 1))
......@@ -1113,3 +1129,74 @@ def test_numpy_warning_suppressed():
y = pt.log(x)
fn = pytensor.function([x], y, mode=Mode(linker="py"))
assert fn(0) == -np.inf
rng = np.random.default_rng(18)
_good_add_inplace = dict(
same_shapes=(random(2, 3, rng=rng), random(2, 3, rng=rng)),
not_same_dimensions=(random(2, 2, rng=rng), random(2, rng=rng)),
scalar=(random(2, 3, rng=rng), random(1, 1, rng=rng)),
row=(random(2, 3, rng=rng), random(1, 3, rng=rng)),
column=(random(2, 3, rng=rng), random(2, 1, rng=rng)),
integers=(integers(2, 3, rng=rng), integers(2, 3, rng=rng)),
uint32=(integers_uint32(2, 3, rng=rng), integers_uint32(2, 3, rng=rng)),
uint16=(integers_uint16(2, 3, rng=rng), integers_uint16(2, 3, rng=rng)),
# (float32, >int16) upcasts to float64 by default
dtype_valid_mixup=(
random(2, 3, rng=rng),
integers(2, 3, rng=rng).astype(
"int16" if config.floatX == "float32" else "int64"
),
),
complex1=(random_complex(2, 3, rng=rng), random_complex(2, 3, rng=rng)),
complex2=(random_complex(2, 3, rng=rng), random(2, 3, rng=rng)),
empty=(np.asarray([], dtype=config.floatX), np.asarray([1], dtype=config.floatX)),
)
TestAddInplaceBroadcast = makeBroadcastTester(
op=Elemwise(scalar_add, {0: 0}),
expected=lambda x, y: x + y,
good=_good_add_inplace,
# Cannot inplace on first input if it doesn't match output dtype (upcast of inputs)
bad_build=dict(dtype_invalid_mixup=_good_add_inplace["dtype_valid_mixup"][::-1]),
bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True,
)
@pytest.mark.xfail(
config.cycle_detection == "fast" and config.mode != "FAST_COMPILE",
reason="Cycle detection is fast and mode is FAST_COMPILE",
)
def test_exp_inplace_grad_1():
utt.verify_grad(
Elemwise(scalar_exp, {0: 0}),
[
np.asarray(
[
[1.5089518, 1.48439076, -4.7820262],
[2.04832468, 0.50791564, -1.58892269],
]
)
],
)
def test_XOR_inplace():
dtype = [
"int8",
"int16",
"int32",
"int64",
]
xor_inplace = Elemwise(scalar_xor, {0: 0})
for dtype in dtype:
x, y = vector(dtype=dtype), vector(dtype=dtype)
l = np.asarray([0, 0, 1, 1], dtype=dtype)
r = np.asarray([0, 1, 0, 1], dtype=dtype)
ix = x
ix = xor_inplace(ix, y)
gn = inplace_func([x, y], ix)
_ = gn(l, r)
# test the in-place stuff
assert np.all(l == np.asarray([0, 1, 1, 0])), l
import numpy as np
import pytest
from pytensor import config
from pytensor.scalar.basic import round_half_away_from_zero_vec, upcast
from pytensor.tensor.inplace import (
abs_inplace,
add_inplace,
arccos_inplace,
arccosh_inplace,
arcsin_inplace,
arcsinh_inplace,
arctan2_inplace,
arctan_inplace,
arctanh_inplace,
ceil_inplace,
conj_inplace,
cos_inplace,
cosh_inplace,
deg2rad_inplace,
exp2_inplace,
exp_inplace,
expm1_inplace,
floor_inplace,
int_div_inplace,
log1p_inplace,
log2_inplace,
log10_inplace,
log_inplace,
maximum_inplace,
minimum_inplace,
mod_inplace,
mul_inplace,
neg_inplace,
pow_inplace,
rad2deg_inplace,
reciprocal_inplace,
round_half_away_from_zero_inplace,
round_half_to_even_inplace,
sign_inplace,
sin_inplace,
sinh_inplace,
sqr_inplace,
sqrt_inplace,
sub_inplace,
tan_inplace,
tanh_inplace,
true_div_inplace,
trunc_inplace,
xor_inplace,
)
from pytensor.tensor.type import vector
from tests import unittest_tools as utt
from tests.tensor.utils import (
_bad_build_broadcast_binary_normal,
_bad_runtime_broadcast_binary_normal,
_bad_runtime_reciprocal,
_good_broadcast_binary_arctan2,
_good_broadcast_binary_normal,
_good_broadcast_div_mod_normal_float_inplace,
_good_broadcast_pow_normal_float_pow,
_good_broadcast_unary_arccosh,
_good_broadcast_unary_arcsin_float,
_good_broadcast_unary_arctanh,
_good_broadcast_unary_normal,
_good_broadcast_unary_normal_abs,
_good_broadcast_unary_normal_float,
_good_broadcast_unary_normal_float_no_complex,
_good_broadcast_unary_normal_float_no_empty_no_complex,
_good_broadcast_unary_normal_no_complex,
_good_broadcast_unary_positive_float,
_good_broadcast_unary_tan,
_good_broadcast_unary_wide_float,
_good_reciprocal_inplace,
_numpy_true_div,
angle_eps,
check_floatX,
copymod,
div_grad_rtol,
ignore_isfinite_mode,
inplace_func,
makeBroadcastTester,
upcast_float16_ufunc,
)
TestAddInplaceBroadcast = makeBroadcastTester(
op=add_inplace,
expected=lambda x, y: x + y,
good=_good_broadcast_binary_normal,
bad_build=_bad_build_broadcast_binary_normal,
bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True,
)
TestSubInplaceBroadcast = makeBroadcastTester(
op=sub_inplace,
expected=lambda x, y: x - y,
good=_good_broadcast_binary_normal,
bad_build=_bad_build_broadcast_binary_normal,
bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True,
)
TestMaximumInplaceBroadcast = makeBroadcastTester(
op=maximum_inplace,
expected=np.maximum,
good=_good_broadcast_binary_normal,
bad_build=_bad_build_broadcast_binary_normal,
bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True,
)
TestMinimumInplaceBroadcast = makeBroadcastTester(
op=minimum_inplace,
expected=np.minimum,
good=_good_broadcast_binary_normal,
bad_build=_bad_build_broadcast_binary_normal,
bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True,
)
TestMulInplaceBroadcast = makeBroadcastTester(
op=mul_inplace,
expected=lambda x, y: x * y,
good=_good_broadcast_binary_normal,
bad_build=_bad_build_broadcast_binary_normal,
bad_runtime=_bad_runtime_broadcast_binary_normal,
inplace=True,
)
TestTrueDivInplaceBroadcast = makeBroadcastTester(
op=true_div_inplace,
expected=_numpy_true_div,
good=copymod(
_good_broadcast_div_mod_normal_float_inplace,
# The output is now in float, we cannot work inplace on an int.
without=["integer", "uint8", "uint16", "int8"],
),
grad_rtol=div_grad_rtol,
inplace=True,
)
TestReciprocalInplaceBroadcast = makeBroadcastTester(
op=reciprocal_inplace,
expected=lambda x: _numpy_true_div(np.int8(1), x),
good=_good_reciprocal_inplace,
bad_runtime=_bad_runtime_reciprocal,
grad_rtol=div_grad_rtol,
inplace=True,
)
TestModInplaceBroadcast = makeBroadcastTester(
op=mod_inplace,
expected=lambda x, y: np.asarray(x % y, dtype=upcast(x.dtype, y.dtype)),
good=copymod(
_good_broadcast_div_mod_normal_float_inplace, ["complex1", "complex2"]
),
grad_eps=1e-5,
inplace=True,
)
TestPowInplaceBroadcast = makeBroadcastTester(
op=pow_inplace,
expected=lambda x, y: x**y,
good=_good_broadcast_pow_normal_float_pow,
inplace=True,
mode=ignore_isfinite_mode,
)
TestNegInplaceBroadcast = makeBroadcastTester(
op=neg_inplace,
expected=lambda x: -x,
good=_good_broadcast_unary_normal,
inplace=True,
)
TestSgnInplaceBroadcast = makeBroadcastTester(
op=sign_inplace,
expected=np.sign,
good=_good_broadcast_unary_normal_no_complex,
inplace=True,
)
TestAbsInplaceBroadcast = makeBroadcastTester(
op=abs_inplace,
expected=lambda x: np.abs(x),
good=_good_broadcast_unary_normal_abs,
inplace=True,
)
TestIntDivInplaceBroadcast = makeBroadcastTester(
op=int_div_inplace,
expected=lambda x, y: check_floatX((x, y), x // y),
good=_good_broadcast_div_mod_normal_float_inplace,
# I don't test the grad as the output is always an integer
# (this is not a continuous output).
# grad=_grad_broadcast_div_mod_normal,
inplace=True,
)
TestCeilInplaceBroadcast = makeBroadcastTester(
op=ceil_inplace,
expected=upcast_float16_ufunc(np.ceil),
good=copymod(
_good_broadcast_unary_normal_no_complex,
without=["integers", "int8", "uint8", "uint16"],
),
# corner cases includes a lot of integers: points where Ceil is not
# continuous (not differentiable)
inplace=True,
)
TestFloorInplaceBroadcast = makeBroadcastTester(
op=floor_inplace,
expected=upcast_float16_ufunc(np.floor),
good=copymod(
_good_broadcast_unary_normal_no_complex,
without=["integers", "int8", "uint8", "uint16"],
),
inplace=True,
)
TestTruncInplaceBroadcast = makeBroadcastTester(
op=trunc_inplace,
expected=upcast_float16_ufunc(np.trunc),
good=_good_broadcast_unary_normal_no_complex,
inplace=True,
)
TestRoundHalfToEvenInplaceBroadcast = makeBroadcastTester(
op=round_half_to_even_inplace,
expected=np.round,
good=_good_broadcast_unary_normal_float_no_complex,
inplace=True,
)
TestRoundHalfAwayFromZeroInplaceBroadcast = makeBroadcastTester(
op=round_half_away_from_zero_inplace,
expected=lambda a: round_half_away_from_zero_vec(a),
good=_good_broadcast_unary_normal_float_no_empty_no_complex,
inplace=True,
)
TestSqrInplaceBroadcast = makeBroadcastTester(
op=sqr_inplace,
expected=np.square,
good=_good_broadcast_unary_normal,
inplace=True,
)
TestExpInplaceBroadcast = makeBroadcastTester(
op=exp_inplace,
expected=np.exp,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestExp2InplaceBroadcast = makeBroadcastTester(
op=exp2_inplace,
expected=np.exp2,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestExpm1InplaceBroadcast = makeBroadcastTester(
op=expm1_inplace,
expected=np.expm1,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestLogInplaceBroadcast = makeBroadcastTester(
op=log_inplace,
expected=np.log,
good=_good_broadcast_unary_positive_float,
inplace=True,
)
TestLog2InplaceBroadcast = makeBroadcastTester(
op=log2_inplace,
expected=np.log2,
good=_good_broadcast_unary_positive_float,
inplace=True,
)
TestLog10InplaceBroadcast = makeBroadcastTester(
op=log10_inplace,
expected=np.log10,
good=_good_broadcast_unary_positive_float,
inplace=True,
)
TestLog1pInplaceBroadcast = makeBroadcastTester(
op=log1p_inplace,
expected=np.log1p,
good=_good_broadcast_unary_positive_float,
inplace=True,
)
TestSqrtInplaceBroadcast = makeBroadcastTester(
op=sqrt_inplace,
expected=np.sqrt,
good=_good_broadcast_unary_positive_float,
inplace=True,
)
TestDeg2radInplaceBroadcast = makeBroadcastTester(
op=deg2rad_inplace,
expected=np.deg2rad,
good=_good_broadcast_unary_normal_float_no_complex,
inplace=True,
eps=angle_eps,
)
TestRad2degInplaceBroadcast = makeBroadcastTester(
op=rad2deg_inplace,
expected=np.rad2deg,
good=_good_broadcast_unary_normal_float_no_complex,
inplace=True,
eps=angle_eps,
)
TestSinInplaceBroadcast = makeBroadcastTester(
op=sin_inplace,
expected=np.sin,
good=_good_broadcast_unary_wide_float,
inplace=True,
)
TestArcsinInplaceBroadcast = makeBroadcastTester(
op=arcsin_inplace,
expected=np.arcsin,
good=_good_broadcast_unary_arcsin_float,
inplace=True,
)
TestCosInplaceBroadcast = makeBroadcastTester(
op=cos_inplace,
expected=np.cos,
good=_good_broadcast_unary_wide_float,
inplace=True,
)
TestArccosInplaceBroadcast = makeBroadcastTester(
op=arccos_inplace,
expected=np.arccos,
good=_good_broadcast_unary_arcsin_float,
inplace=True,
)
TestTanInplaceBroadcast = makeBroadcastTester(
op=tan_inplace,
expected=np.tan,
good=copymod(
_good_broadcast_unary_tan, without=["integers", "int8", "uint8", "uint16"]
),
inplace=True,
)
TestArctanInplaceBroadcast = makeBroadcastTester(
op=arctan_inplace,
expected=np.arctan,
good=_good_broadcast_unary_wide_float,
inplace=True,
)
TestArctan2InplaceBroadcast = makeBroadcastTester(
op=arctan2_inplace,
expected=np.arctan2,
good=copymod(
_good_broadcast_binary_arctan2,
without=["integers", "int8", "uint8", "uint16", "dtype_mixup_2"],
),
inplace=True,
)
TestCoshInplaceBroadcast = makeBroadcastTester(
op=cosh_inplace,
expected=np.cosh,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestArccoshInplaceBroadcast = makeBroadcastTester(
op=arccosh_inplace,
expected=np.arccosh,
good=copymod(_good_broadcast_unary_arccosh, without=["integers", "uint8"]),
inplace=True,
)
TestSinhInplaceBroadcast = makeBroadcastTester(
op=sinh_inplace,
expected=np.sinh,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestArcsinhInplaceBroadcast = makeBroadcastTester(
op=arcsinh_inplace,
expected=np.arcsinh,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestTanhInplaceBroadcast = makeBroadcastTester(
op=tanh_inplace,
expected=np.tanh,
good=_good_broadcast_unary_normal_float,
inplace=True,
)
TestArctanhInplaceBroadcast = makeBroadcastTester(
op=arctanh_inplace,
expected=np.arctanh,
good=copymod(
_good_broadcast_unary_arctanh, without=["integers", "int8", "uint8", "uint16"]
),
inplace=True,
)
TestConjInplaceBroadcast = makeBroadcastTester(
op=conj_inplace,
expected=np.conj,
good=_good_broadcast_unary_normal,
inplace=True,
)
@pytest.mark.xfail(
config.cycle_detection == "fast" and config.mode != "FAST_COMPILE",
reason="Cycle detection is fast and mode is FAST_COMPILE",
)
def test_exp_inplace_grad_1():
utt.verify_grad(
exp_inplace,
[
np.asarray(
[
[1.5089518, 1.48439076, -4.7820262],
[2.04832468, 0.50791564, -1.58892269],
]
)
],
)
def test_XOR_inplace():
dtype = [
"int8",
"int16",
"int32",
"int64",
]
for dtype in dtype:
x, y = vector(dtype=dtype), vector(dtype=dtype)
l = np.asarray([0, 0, 1, 1], dtype=dtype)
r = np.asarray([0, 1, 0, 1], dtype=dtype)
ix = x
ix = xor_inplace(ix, y)
gn = inplace_func([x, y], ix)
_ = gn(l, r)
# test the in-place stuff
assert np.all(l == np.asarray([0, 1, 1, 0])), l
......@@ -12,13 +12,12 @@ from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import NullTypeGradError, verify_grad
from pytensor.scalar import ScalarLoop
from pytensor.tensor import gammaincc, inplace, kn, kv, kve, vector
from pytensor.tensor import gammaincc, kn, kv, kve, vector
from pytensor.tensor.elemwise import Elemwise
from tests import unittest_tools as utt
from tests.tensor.utils import (
_good_broadcast_unary_chi2sf,
_good_broadcast_unary_normal,
_good_broadcast_unary_normal_float,
_good_broadcast_unary_normal_float_no_complex,
_good_broadcast_unary_normal_float_no_complex_small_neg_range,
_good_broadcast_unary_normal_no_complex,
......@@ -85,14 +84,6 @@ TestErfBroadcast = makeBroadcastTester(
eps=2e-10,
mode=mode_no_scipy,
)
TestErfInplaceBroadcast = makeBroadcastTester(
op=inplace.erf_inplace,
expected=expected_erf,
good=_good_broadcast_unary_normal_float,
mode=mode_no_scipy,
eps=2e-10,
inplace=True,
)
TestErfcBroadcast = makeBroadcastTester(
op=pt.erfc,
......@@ -102,14 +93,6 @@ TestErfcBroadcast = makeBroadcastTester(
eps=2e-10,
mode=mode_no_scipy,
)
TestErfcInplaceBroadcast = makeBroadcastTester(
op=inplace.erfc_inplace,
expected=expected_erfc,
good=_good_broadcast_unary_normal_float_no_complex,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestErfcxBroadcast = makeBroadcastTester(
op=pt.erfcx,
......@@ -119,14 +102,6 @@ TestErfcxBroadcast = makeBroadcastTester(
eps=2e-10,
mode=mode_no_scipy,
)
TestErfcxInplaceBroadcast = makeBroadcastTester(
op=inplace.erfcx_inplace,
expected=expected_erfcx,
good=_good_broadcast_unary_normal_float_no_complex_small_neg_range,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestErfinvBroadcast = makeBroadcastTester(
op=pt.erfinv,
......@@ -192,14 +167,6 @@ TestOwensTBroadcast = makeBroadcastTester(
eps=2e-10,
mode=mode_no_scipy,
)
TestOwensTInplaceBroadcast = makeBroadcastTester(
op=inplace.owens_t_inplace,
expected=expected_owenst,
good=_good_broadcast_binary_owenst,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
rng = np.random.default_rng(seed=utt.fetch_seed())
_good_broadcast_unary_gammaln = dict(
......@@ -223,14 +190,6 @@ TestGammaBroadcast = makeBroadcastTester(
mode=mode_no_scipy,
eps=1e-5,
)
TestGammaInplaceBroadcast = makeBroadcastTester(
op=inplace.gamma_inplace,
expected=expected_gamma,
good=_good_broadcast_unary_gammaln,
mode=mode_no_scipy,
eps=1e-5,
inplace=True,
)
TestGammalnBroadcast = makeBroadcastTester(
op=pt.gammaln,
......@@ -240,14 +199,6 @@ TestGammalnBroadcast = makeBroadcastTester(
eps=2e-10,
mode=mode_no_scipy,
)
TestGammalnInplaceBroadcast = makeBroadcastTester(
op=inplace.gammaln_inplace,
expected=expected_gammaln,
good=_good_broadcast_unary_gammaln,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
rng = np.random.default_rng(seed=utt.fetch_seed())
_good_broadcast_unary_psi = dict(
......@@ -265,14 +216,6 @@ TestPsiBroadcast = makeBroadcastTester(
eps=2e-10,
mode=mode_no_scipy,
)
TestPsiInplaceBroadcast = makeBroadcastTester(
op=inplace.psi_inplace,
expected=expected_psi,
good=_good_broadcast_unary_psi,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
_good_broadcast_unary_tri_gamma = _good_broadcast_unary_psi
......@@ -283,14 +226,6 @@ TestTriGammaBroadcast = makeBroadcastTester(
eps=2e-8,
mode=mode_no_scipy,
)
TestTriGammaInplaceBroadcast = makeBroadcastTester(
op=inplace.tri_gamma_inplace,
expected=expected_tri_gamma,
good=_good_broadcast_unary_tri_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
TestChi2SFBroadcast = makeBroadcastTester(
op=pt.chi2sf,
......@@ -343,15 +278,6 @@ TestGammaIncBroadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestGammaIncInplaceBroadcast = makeBroadcastTester(
op=inplace.gammainc_inplace,
expected=expected_gammainc,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
TestGammaInccBroadcast = makeBroadcastTester(
op=pt.gammaincc,
expected=expected_gammaincc,
......@@ -361,15 +287,6 @@ TestGammaInccBroadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestGammaInccInplaceBroadcast = makeBroadcastTester(
op=inplace.gammaincc_inplace,
expected=expected_gammaincc,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
def test_gammainc_ddk_tabulated_values():
# This test replicates part of the old STAN test:
......@@ -447,15 +364,6 @@ TestGammaUBroadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestGammaUInplaceBroadcast = makeBroadcastTester(
op=inplace.gammau_inplace,
expected=expected_gammau,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
TestGammaLBroadcast = makeBroadcastTester(
op=pt.gammal,
expected=expected_gammal,
......@@ -464,15 +372,6 @@ TestGammaLBroadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestGammaLInplaceBroadcast = makeBroadcastTester(
op=inplace.gammal_inplace,
expected=expected_gammal,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
rng = np.random.default_rng(seed=utt.fetch_seed())
_good_broadcast_binary_gamma = dict(
normal=(
......@@ -490,15 +389,6 @@ TestGammaIncInvBroadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestGammaIncInvInplaceBroadcast = makeBroadcastTester(
op=inplace.gammaincinv_inplace,
expected=expected_gammaincinv,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
TestGammaInccInvBroadcast = makeBroadcastTester(
op=pt.gammainccinv,
expected=expected_gammainccinv,
......@@ -507,15 +397,6 @@ TestGammaInccInvBroadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestGammaInccInvInplaceBroadcast = makeBroadcastTester(
op=inplace.gammainccinv_inplace,
expected=expected_gammainccinv,
good=_good_broadcast_binary_gamma,
eps=2e-8,
mode=mode_no_scipy,
inplace=True,
)
rng = np.random.default_rng(seed=utt.fetch_seed())
_good_broadcast_unary_bessel = dict(
normal=(random_ranged(-10, 10, (2, 3), rng=rng),),
......@@ -562,15 +443,6 @@ TestJ0Broadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestJ0InplaceBroadcast = makeBroadcastTester(
op=inplace.j0_inplace,
expected=expected_j0,
good=_good_broadcast_unary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestJ1Broadcast = makeBroadcastTester(
op=pt.j1,
expected=expected_j1,
......@@ -580,15 +452,6 @@ TestJ1Broadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestJ1InplaceBroadcast = makeBroadcastTester(
op=inplace.j1_inplace,
expected=expected_j1,
good=_good_broadcast_unary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestJvBroadcast = makeBroadcastTester(
op=pt.jv,
expected=expected_jv,
......@@ -597,15 +460,6 @@ TestJvBroadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestJvInplaceBroadcast = makeBroadcastTester(
op=inplace.jv_inplace,
expected=expected_jv,
good=_good_broadcast_binary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
def test_verify_jv_grad():
# Verify Jv gradient.
......@@ -628,15 +482,6 @@ TestI0Broadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestI0InplaceBroadcast = makeBroadcastTester(
op=inplace.i0_inplace,
expected=expected_i0,
good=_good_broadcast_unary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestI1Broadcast = makeBroadcastTester(
op=pt.i1,
expected=expected_i1,
......@@ -646,15 +491,6 @@ TestI1Broadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestI1InplaceBroadcast = makeBroadcastTester(
op=inplace.i1_inplace,
expected=expected_i1,
good=_good_broadcast_unary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestIvBroadcast = makeBroadcastTester(
op=pt.iv,
expected=expected_iv,
......@@ -663,15 +499,6 @@ TestIvBroadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestIvInplaceBroadcast = makeBroadcastTester(
op=inplace.iv_inplace,
expected=expected_iv,
good=_good_broadcast_binary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
TestIveBroadcast = makeBroadcastTester(
op=pt.ive,
expected=expected_ive,
......@@ -680,15 +507,6 @@ TestIveBroadcast = makeBroadcastTester(
mode=mode_no_scipy,
)
TestIveInplaceBroadcast = makeBroadcastTester(
op=inplace.ive_inplace,
expected=expected_ive,
good=_good_broadcast_binary_bessel,
eps=2e-10,
mode=mode_no_scipy,
inplace=True,
)
def test_verify_iv_grad():
# Verify Iv gradient.
......@@ -721,15 +539,6 @@ TestSigmoidBroadcast = makeBroadcastTester(
eps=1e-8,
)
TestSigmoidInplaceBroadcast = makeBroadcastTester(
op=inplace.sigmoid_inplace,
expected=expected_sigmoid,
good=_good_broadcast_unary_normal_no_complex,
grad=_grad_broadcast_unary_normal,
eps=1e-8,
inplace=True,
)
class TestSigmoid:
def test_elemwise(self):
......@@ -758,15 +567,6 @@ TestSoftplusBroadcast = makeBroadcastTester(
eps=1e-8,
)
TestSoftplusInplaceBroadcast = makeBroadcastTester(
op=inplace.softplus_inplace,
expected=expected_sofplus,
good=_good_broadcast_unary_softplus,
grad=_grad_broadcast_unary_normal,
eps=1e-8,
inplace=True,
)
class TestSoftplus:
def test_elemwise(self):
......@@ -805,14 +605,6 @@ TestLog1mexpBroadcast = makeBroadcastTester(
eps=1e-8,
)
TestLog1mexpInplaceBroadcast = makeBroadcastTester(
op=inplace.log1mexp_inplace,
expected=expected_log1mexp,
good=_good_broadcast_unary_log1mexp,
eps=1e-8,
inplace=True,
)
_good_broadcast_ternary_betainc = dict(
normal=(
random_ranged(0, 1000, (2, 3)),
......@@ -828,14 +620,6 @@ TestBetaincBroadcast = makeBroadcastTester(
grad=_good_broadcast_ternary_betainc,
)
TestBetaincInplaceBroadcast = makeBroadcastTester(
op=inplace.betainc_inplace,
expected=special.betainc,
good=_good_broadcast_ternary_betainc,
grad=_good_broadcast_ternary_betainc,
inplace=True,
)
class TestBetaIncGrad:
def test_stan_grad_partial(self):
......@@ -926,13 +710,6 @@ TestBetaincinvBroadcast = makeBroadcastTester(
good=_good_broadcast_ternary_betaincinv,
)
TestBetaincinvInplaceBroadcast = makeBroadcastTester(
op=inplace.betaincinv_inplace,
expected=special.betaincinv,
good=_good_broadcast_ternary_betaincinv,
inplace=True,
)
_good_broadcast_quaternary_hyp2f1 = dict(
normal=(
random_ranged(0, 20, (2, 3)),
......@@ -949,13 +726,6 @@ TestHyp2F1Broadcast = makeBroadcastTester(
grad=_good_broadcast_quaternary_hyp2f1,
)
TestHyp2F1InplaceBroadcast = makeBroadcastTester(
op=inplace.hyp2f1_inplace,
expected=expected_hyp2f1,
good=_good_broadcast_quaternary_hyp2f1,
inplace=True,
)
class TestHyp2F1Grad:
few_iters_case = (
......
......@@ -672,7 +672,9 @@ def makeTester(
return Checker
def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs):
def makeBroadcastTester(
op, expected, checks=None, name=None, *, inplace=False, **kwargs
):
if checks is None:
checks = {}
if name is None:
......@@ -695,22 +697,20 @@ def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs):
# cases we need to add it manually.
if not name.endswith("Tester"):
name += "Tester"
if "inplace" in kwargs:
if kwargs["inplace"]:
_expected = expected
if not isinstance(_expected, dict):
def expected(*inputs):
return np.array(_expected(*inputs), dtype=inputs[0].dtype)
def inplace_check(inputs, outputs):
# this used to be inputs[0] is output[0]
# I changed it so that it was easier to satisfy by the
# DebugMode
return np.all(inputs[0] == outputs[0])
checks = dict(checks, inplace_check=inplace_check)
del kwargs["inplace"]
if inplace:
_expected = expected
if not isinstance(_expected, dict):
def expected(*inputs):
return np.array(_expected(*inputs), dtype=inputs[0].dtype)
def inplace_check(inputs, outputs):
# this used to be inputs[0] is output[0]
# I changed it so that it was easier to satisfy by the
# DebugMode
return np.all(inputs[0] == outputs[0])
checks = dict(checks, inplace_check=inplace_check)
return makeTester(name, op, expected, checks, **kwargs)
......@@ -815,6 +815,7 @@ _good_broadcast_unary_normal_no_complex = dict(
big_scalar=[np.arange(17.0, 29.0, 0.5, dtype=config.floatX)],
)
# FIXME: Why is this empty?
_bad_build_broadcast_binary_normal = dict()
_bad_runtime_broadcast_binary_normal = dict(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论