提交 bfabe261 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename inv to reciprocal

上级 9ac6149c
......@@ -1543,7 +1543,7 @@ class ProfileStats:
aes.Cast,
aes.Sgn,
aes.Neg,
aes.Inv,
aes.Reciprocal,
aes.Sqr,
]
scalar_op_amdlibm_speed_up = [
......
......@@ -3385,7 +3385,7 @@ def dnn_batch_normalization_train(
axes = 0 if mode == 'per-activation' else (0, 2, 3)
mean = inputs.mean(axes, keepdims=True)
var = inputs.var(axes, keepdims=True)
invstd = aet.inv(aet.sqrt(var + epsilon))
invstd = aet.reciprocal(aet.sqrt(var + epsilon))
out = (inputs - mean) * gamma * invstd + beta
m = aet.cast(aet.prod(inputs.shape) / aet.prod(mean.shape), 'float32')
......
......@@ -2846,11 +2846,8 @@ pprint.assign(pow, printing.OperatorPrinter("**", 1, "right"))
pprint.assign(mod, printing.OperatorPrinter("%", -1, "left"))
class Inv(UnaryScalarOp):
"""
Multiplicative inverse. Also called reciprocal.
"""
class Reciprocal(UnaryScalarOp):
"""Multiplicative inverse."""
nfunc_spec = ("reciprocal", 1, 1)
......@@ -2878,7 +2875,11 @@ class Inv(UnaryScalarOp):
return f"{z} = 1.0 / {x};"
inv = Inv(upgrade_to_float, name="inv")
reciprocal = Reciprocal(upgrade_to_float, name="reciprocal")
# These are deprecated and will be removed
Inv = Reciprocal
inv = reciprocal
class Log(UnaryScalarOp):
......
......@@ -79,7 +79,7 @@ def neg_inplace(a):
@scalar_elemwise
def inv_inplace(a):
def reciprocal_inplace(a):
"""1.0/a (inplace on a)"""
......
......@@ -1065,13 +1065,15 @@ def neg(a):
"""-a"""
# numpy.reciprocal does integer division on integer inputs
# (which is not very interesting)
@scalar_elemwise
def inv(a):
def reciprocal(a):
"""1.0/a"""
# This is deprecated and will be removed
inv = reciprocal
@scalar_elemwise
def log(a):
"""base e logarithm of a"""
......@@ -2789,6 +2791,7 @@ __all__ = [
"exp2",
"expm1",
"neg",
"reciprocal",
"inv",
"log",
"log2",
......
......@@ -74,7 +74,6 @@ from aesara.tensor.math import (
expm1,
ge,
int_div,
inv,
log,
log1p,
makeKeepDims,
......@@ -82,7 +81,7 @@ from aesara.tensor.math import (
from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, mul, neg
from aesara.tensor.math import pow as aet_pow
from aesara.tensor.math import prod, sgn, sigmoid, softplus, sqr, sqrt, sub
from aesara.tensor.math import prod, reciprocal, sgn, sigmoid, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import true_div
from aesara.tensor.shape import Shape, Shape_i
......@@ -206,7 +205,7 @@ def local_func_inv(fgraph, node):
(aes.Sinh, aes.ArcSinh),
(aes.Conj, aes.Conj),
(aes.Neg, aes.Neg),
(aes.Inv, aes.Inv),
(aes.Reciprocal, aes.Reciprocal),
)
x = node.inputs[0]
......@@ -511,28 +510,29 @@ def local_div_switch_sink(fgraph, node):
class AlgebraicCanonizer(LocalOptimizer):
r"""
Simplification tool. The variable is a local_optimizer. It is best used
with a TopoOptimizer in in_to_out order.
r"""Simplification tool.
Usage: AlgebraicCanonizer(main, inverse, reciprocal, calculate)
The variable is a ``local_optimizer``. It is best used
with a ``TopoOptimizer`` in ``in_to_out`` order.
Usage: ``AlgebraicCanonizer(main, inverse, reciprocal, calculate)``
Parameters
----------
main
A suitable Op class that is commutative, associative and
A suitable ``Op`` class that is commutative, associative and
takes one to an arbitrary number of inputs, e.g. add or
mul
inverse
An Op class such that inverse(main(x, y), y) == x
e.g. sub or true_div
An ``Op`` class such that ``inverse(main(x, y), y) == x``
e.g. ``sub`` or true_div
reciprocal
A function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. neg or inv
A function such that ``main(x, reciprocal(y)) == inverse(x, y)``
e.g. ``neg`` or ``reciprocal``
calculate
Function that takes a list of numpy.ndarray instances
for the numerator, another list for the denumerator,
and calculates inverse(main(\*num), main(\*denum)). It
and calculates ``inverse(main(\*num), main(\*denum))``. It
takes a keyword argument, aslist. If True, the value
should be returned as a list of one element, unless
the value is such that value = main(). In that case,
......@@ -547,7 +547,7 @@ class AlgebraicCanonizer(LocalOptimizer):
>>> mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, \\
... lambda n, d: prod(n) / prod(d))
Examples of optimizations mul_canonizer can perform:
Examples of optimizations ``mul_canonizer`` can perform:
| x / x -> 1
| (x * y) / x -> y
......@@ -562,10 +562,10 @@ class AlgebraicCanonizer(LocalOptimizer):
"""
def __init__(self, main, inverse, reciprocal, calculate, use_reciprocal=True):
def __init__(self, main, inverse_fn, reciprocal_fn, calculate, use_reciprocal=True):
self.main = main
self.inverse = inverse
self.reciprocal = reciprocal
self.inverse = inverse_fn
self.reciprocal = reciprocal_fn
self.calculate = calculate
self.use_reciprocal = use_reciprocal
......@@ -579,11 +579,11 @@ class AlgebraicCanonizer(LocalOptimizer):
def get_num_denum(self, input):
r"""
This extract two lists, num and denum, such that the input is:
self.inverse(self.main(\*num), self.main(\*denum)). It returns
the two lists in a (num, denum) pair.
This extract two lists, ``num`` and ``denum``, such that the input is:
``self.inverse(self.main(\*num), self.main(\*denum))``. It returns
the two lists in a ``(num, denum)`` pair.
For example, for main, inverse and reciprocal = \*, / and inv(),
For example, for main, inverse and ``reciprocal = \*, / and inv()``,
| input -> returned value (num, denum)
......@@ -1013,7 +1013,9 @@ def mul_calculate(num, denum, aslist=False, out_type=None):
return v
local_mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, mul_calculate, False)
local_mul_canonizer = AlgebraicCanonizer(
mul, true_div, reciprocal, mul_calculate, False
)
register_canonicalize(local_mul_canonizer, name="local_mul_canonizer")
......@@ -1847,12 +1849,12 @@ register_canonicalize(local_mul_zero)
@local_optimizer([true_div])
def local_div_to_inv(fgraph, node):
def local_div_to_reciprocal(fgraph, node):
if node.op == true_div and np.all(
local_mul_canonizer.get_constant(node.inputs[0]) == 1.0
):
out = node.outputs[0]
new_out = inv(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))
# The ones could have forced upcasting
if new_out.dtype != out.dtype:
new_out = cast(new_out, dtype=out.dtype)
......@@ -1864,18 +1866,19 @@ def local_div_to_inv(fgraph, node):
return False
register_specialize(local_div_to_inv)
# TODO: Add this to the canonicalization to reduce redundancy.
register_specialize(local_div_to_reciprocal)
@local_optimizer([inv])
def local_inv_canon(fgraph, node):
if node.op == inv:
@local_optimizer([reciprocal])
def local_reciprocal_canon(fgraph, node):
if node.op == reciprocal:
return [aet_pow(node.inputs[0], -1.0)]
else:
return False
register_canonicalize(local_inv_canon)
register_canonicalize(local_reciprocal_canon)
@local_optimizer([aet_pow])
......@@ -1958,11 +1961,11 @@ def local_pow_specialize(fgraph, node):
if np.all(y == 0.5):
rval = [sqrt(xsym)]
if np.all(y == -0.5):
rval = [inv(sqrt(xsym))]
rval = [reciprocal(sqrt(xsym))]
if np.all(y == -1):
rval = [inv(xsym)]
rval = [reciprocal(xsym)]
if np.all(y == -2):
rval = [inv(sqr(xsym))]
rval = [reciprocal(sqr(xsym))]
if rval:
rval[0] = cast(rval[0], odtype)
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
......@@ -2032,7 +2035,7 @@ def local_pow_specialize_device(fgraph, node):
aes.Composite([pow2_scal[0]], [rval1_scal])
).make_node(xsym)
if y < 0:
rval = [inv(rval1)]
rval = [reciprocal(rval1)]
else:
rval = [rval1]
if rval:
......@@ -2476,7 +2479,7 @@ def attempt_distribution(factor, num, denum, out_type):
@register_canonicalize
@register_stabilize
@local_optimizer([mul, true_div, inv])
@local_optimizer([mul, true_div, reciprocal])
def local_greedy_distributor(fgraph, node):
"""
Optimize by reducing the number of multiplications and/or divisions.
......@@ -3584,19 +3587,21 @@ def local_sigm_times_exp(fgraph, node):
@register_stabilize
@local_optimizer([inv])
def local_inv_1_plus_exp(fgraph, node):
"""
1/(1+exp(x)) -> sigm(-x)
@local_optimizer([reciprocal])
def local_reciprocal_1_plus_exp(fgraph, node):
"""``reciprocal(1+exp(x)) -> sigm(-x)``
TODO: This is redundant; we can just decide on *one* canonical form
for division (e.g. either `true_div` or `reciprocal`) and have this
taken care of with existing rewrites.
"""
# this optimization should be done for numerical stability
# so we don't care to check client counts
if node.op == inv:
inv_arg = node.inputs[0]
if inv_arg.owner and inv_arg.owner.op == add:
if node.op == reciprocal:
reciprocal_arg = node.inputs[0]
if reciprocal_arg.owner and reciprocal_arg.owner.op == add:
scalars_, scalar_inputs, nonconsts = scalarconsts_rest(
inv_arg.owner.inputs, only_process_constants=True
reciprocal_arg.owner.inputs, only_process_constants=True
)
# scalar_inputs are potentially dimshuffled and fill'd scalars
if len(nonconsts) == 1:
......@@ -3608,9 +3613,11 @@ def local_inv_1_plus_exp(fgraph, node):
)
# keep combined stack traces of
# exp(x): nonconsts[0],
# 1 + exp(x): inv_arg,
# 1 + exp(x): reciprocal_arg,
# 1 / (1 + exp(x)): node.outputs[0]
copy_stack_trace([nonconsts[0], inv_arg, node.outputs[0]], out)
copy_stack_trace(
[nonconsts[0], reciprocal_arg, node.outputs[0]], out
)
return out
......
......@@ -10,7 +10,7 @@ from aesara.tensor import basic as aet
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.basic_opt import register_specialize_device
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import inv, mean, prod, sqrt
from aesara.tensor.math import mean, prod, reciprocal, sqrt
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.type import TensorType
......@@ -185,7 +185,7 @@ def batch_normalization_train(
axes = (0,) + tuple(range(2, inputs.ndim))
mean = inputs.mean(axes, keepdims=True)
var = inputs.var(axes, keepdims=True)
invstd = aet.inv(aet.sqrt(var + epsilon))
invstd = aet.reciprocal(aet.sqrt(var + epsilon))
out = (inputs - mean) * gamma * invstd + beta
m = aet.cast(aet.prod(inputs.shape) / aet.prod(mean.shape), 'float32')
......@@ -802,7 +802,7 @@ def local_abstract_batch_norm_train(fgraph, node):
# The epsilon should not upcast the dtype.
if var.dtype == "float32" and epsilon.dtype == "float64":
epsilon = epsilon.astype("float32")
invstd = inv(sqrt(var + epsilon))
invstd = reciprocal(sqrt(var + epsilon))
out = (x - mean) * (scale * invstd) + bias
results = [out, mean, invstd]
......
......@@ -1435,7 +1435,7 @@ Mathematical
Returns a variable representing the negation of `a` (also ``-a``).
.. function:: inv(a)
.. function:: reciprocal(a)
Returns a variable representing the inverse of a, i.e. 1.0/a. Also called the reciprocal.
......
......@@ -205,7 +205,7 @@ Optimization o4 o3 o2
recognize them. Some examples include:
* ``pow(x,2)`` -> ``x**2``
* ``pow(x,0)`` -> ``ones_like(x)``
* ``pow(x, -0.5)`` -> ``inv(sqrt(x))``
* ``pow(x, -0.5)`` -> ``reciprocal(sqrt(x))``
See :func:`local_pow_specialize`
......
......@@ -22,13 +22,13 @@ from aesara.tensor.math import (
clip,
dot,
floor,
inv,
log,
max_and_argmax,
mean,
minimum,
mod,
prod,
reciprocal,
sqrt,
)
from aesara.tensor.math import sum as aet_sum
......@@ -1879,7 +1879,7 @@ def test_dnn_batchnorm_train():
axes = (0,) + tuple(range(2, ndim))
x_mean_ref = x.mean(axis=axes, keepdims=True)
x_var_ref = x.var(axis=axes, keepdims=True)
x_invstd_ref = inv(sqrt(x_var_ref + eps))
x_invstd_ref = reciprocal(sqrt(x_var_ref + eps))
scale_ref = aet.addbroadcast(scale, *axes)
bias_ref = aet.addbroadcast(bias, *axes)
m = aet.cast(prod(x.shape) / prod(scale.shape), aesara.config.floatX)
......
......@@ -46,7 +46,6 @@ from aesara.scalar.basic import (
int8,
int32,
ints,
inv,
invert,
log,
log1p,
......@@ -55,6 +54,7 @@ from aesara.scalar.basic import (
mul,
neq,
rad2deg,
reciprocal,
sin,
sinh,
sqrt,
......@@ -281,7 +281,7 @@ class TestUpgradeToFloat:
# at least float32, not float16.
unary_ops_vals = [
(inv, list(range(-127, 0)) + list(range(1, 127))),
(reciprocal, list(range(-127, 0)) + list(range(1, 127))),
(sqrt, list(range(0, 128))),
(log, list(range(1, 128))),
(log2, list(range(1, 128))),
......
......@@ -219,7 +219,7 @@ def test_batch_normalization_train():
axes2 = axes
x_mean2 = x.mean(axis=axes2, keepdims=True)
x_var2 = x.var(axis=axes2, keepdims=True)
x_invstd2 = aet.inv(aet.sqrt(x_var2 + eps))
x_invstd2 = aet.reciprocal(aet.sqrt(x_var2 + eps))
scale2 = aet.addbroadcast(scale, *axes2)
bias2 = aet.addbroadcast(bias, *axes2)
out2 = (x - x_mean2) * (scale2 * x_invstd2) + bias2
......
......@@ -66,7 +66,6 @@ from aesara.tensor.math import (
ge,
gt,
int_div,
inv,
invert,
iround,
le,
......@@ -81,6 +80,7 @@ from aesara.tensor.math import (
neq,
)
from aesara.tensor.math import pow as aet_pow
from aesara.tensor.math import reciprocal
from aesara.tensor.math import round as aet_round
from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as aet_sum
......@@ -845,7 +845,7 @@ class TestFusion:
"float32",
),
(
fx - fy + inv(fz),
fx - fy + reciprocal(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
......
......@@ -24,7 +24,6 @@ from aesara.tensor.inplace import (
expm1_inplace,
floor_inplace,
int_div_inplace,
inv_inplace,
log1p_inplace,
log2_inplace,
log10_inplace,
......@@ -36,6 +35,7 @@ from aesara.tensor.inplace import (
neg_inplace,
pow_inplace,
rad2deg_inplace,
reciprocal_inplace,
round_half_away_from_zero_inplace,
round_half_to_even_inplace,
sgn_inplace,
......@@ -55,7 +55,7 @@ from tests import unittest_tools as utt
from tests.tensor.utils import (
_bad_build_broadcast_binary_normal,
_bad_runtime_broadcast_binary_normal,
_bad_runtime_inv,
_bad_runtime_reciprocal,
_good_broadcast_binary_arctan2,
_good_broadcast_binary_normal,
_good_broadcast_div_mod_normal_float_inplace,
......@@ -72,7 +72,7 @@ from tests.tensor.utils import (
_good_broadcast_unary_positive_float,
_good_broadcast_unary_tan,
_good_broadcast_unary_wide_float,
_good_inv_inplace,
_good_reciprocal_inplace,
_numpy_true_div,
angle_eps,
check_floatX,
......@@ -142,11 +142,11 @@ TestTrueDivInplaceBroadcast = makeBroadcastTester(
inplace=True,
)
TestInvInplaceBroadcast = makeBroadcastTester(
op=inv_inplace,
TestReciprocalInplaceBroadcast = makeBroadcastTester(
op=reciprocal_inplace,
expected=lambda x: _numpy_true_div(np.int8(1), x),
good=_good_inv_inplace,
bad_runtime=_bad_runtime_inv,
good=_good_reciprocal_inplace,
bad_runtime=_bad_runtime_reciprocal,
grad_rtol=div_grad_rtol,
inplace=True,
)
......
......@@ -67,7 +67,6 @@ from aesara.tensor.math import (
exp2,
expm1,
floor,
inv,
isclose,
isinf,
isnan,
......@@ -90,6 +89,7 @@ from aesara.tensor.math import (
power,
ptp,
rad2deg,
reciprocal,
round_half_away_from_zero,
round_half_to_even,
sgn,
......@@ -136,7 +136,7 @@ from tests import unittest_tools as utt
from tests.tensor.utils import (
_bad_build_broadcast_binary_normal,
_bad_runtime_broadcast_binary_normal,
_bad_runtime_inv,
_bad_runtime_reciprocal,
_eps,
_good_broadcast_binary_arctan2,
_good_broadcast_binary_normal,
......@@ -153,14 +153,14 @@ from tests.tensor.utils import (
_good_broadcast_unary_positive,
_good_broadcast_unary_tan,
_good_broadcast_unary_wide,
_good_inv,
_good_reciprocal,
_grad_broadcast_binary_normal,
_grad_broadcast_pow_normal,
_grad_broadcast_unary_normal,
_grad_broadcast_unary_normal_no_complex,
_grad_broadcast_unary_normal_no_complex_no_corner_case,
_grad_broadcast_unary_normal_noint,
_grad_inv,
_grad_reciprocal,
_numpy_true_div,
angle_eps,
check_floatX,
......@@ -308,11 +308,11 @@ TestTrueDivBroadcast = makeBroadcastTester(
)
TestInvBroadcast = makeBroadcastTester(
op=inv,
op=reciprocal,
expected=lambda x: upcast_int8_nfunc(np.true_divide)(np.int8(1), x),
good=_good_inv,
bad_runtime=_bad_runtime_inv,
grad=_grad_inv,
good=_good_reciprocal,
bad_runtime=_bad_runtime_reciprocal,
grad=_grad_reciprocal,
grad_rtol=div_grad_rtol,
)
......
......@@ -52,7 +52,6 @@ from aesara.tensor.math import (
ge,
gt,
int_div,
inv,
invert,
iround,
le,
......@@ -67,7 +66,7 @@ from aesara.tensor.math import maximum
from aesara.tensor.math import min as aet_min
from aesara.tensor.math import minimum, mul, neg, neq
from aesara.tensor.math import pow as aet_pow
from aesara.tensor.math import prod, rad2deg
from aesara.tensor.math import prod, rad2deg, reciprocal
from aesara.tensor.math import round as aet_round
from aesara.tensor.math import sgn, sigmoid, sin, sinh, sqr, sqrt, sub
from aesara.tensor.math import sum as aet_sum
......@@ -595,9 +594,9 @@ class TestAlgebraicCanonize:
((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
# must broadcast as there is a dimshuffle in the computation
((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"),
# topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Alloc]
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"),
# topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Alloc]
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float32, row)>), Alloc]
]
):
f = function(list(sym_inputs), g, mode=mode)
......@@ -609,7 +608,7 @@ class TestAlgebraicCanonize:
assert isinstance(elem[0].op, (Elemwise,))
assert isinstance(
elem[0].op.scalar_op,
(aes.basic.Inv, aes.basic.TrueDiv),
(aes.basic.Reciprocal, aes.basic.TrueDiv),
)
assert out_dtype == out.dtype
......@@ -912,7 +911,7 @@ class TestAlgebraicCanonize:
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert isinstance(topo[0].op, (Elemwise,))
assert isinstance(topo[0].op.scalar_op, aes.basic.Inv)
assert isinstance(topo[0].op.scalar_op, aes.basic.Reciprocal)
assert len(topo[0].inputs) == 1
assert out_dtype == out.dtype
......@@ -927,7 +926,7 @@ class TestAlgebraicCanonize:
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert isinstance(topo[0].op, (Elemwise,))
assert isinstance(topo[0].op.scalar_op, aes.basic.Inv)
assert isinstance(topo[0].op.scalar_op, aes.basic.Reciprocal)
assert len(topo[0].inputs) == 1
assert out_dtype == out.dtype
......@@ -1545,7 +1544,7 @@ class TestFusion:
"float32",
),
(
fx - fy + inv(fz),
fx - fy + reciprocal(fz),
(fx, fy, fz),
(fxv, fyv, fzv),
1,
......@@ -2360,7 +2359,7 @@ def test_local_pow_specialize():
f = function([v], v ** (-1), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
assert nodes == [inv]
assert nodes == [reciprocal]
utt.assert_allclose(f(val_no0), val_no0 ** (-1))
f = function([v], v ** 2, mode=mode)
......@@ -2372,7 +2371,7 @@ def test_local_pow_specialize():
nodes = [node.op for node in f.maker.fgraph.toposort()]
assert len(nodes) == 2
assert nodes[0] == sqr
assert isinstance(nodes[1].scalar_op, aes.basic.Inv)
assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal)
utt.assert_allclose(f(val_no0), val_no0 ** (-2))
f = function([v], v ** (0.5), mode=mode)
......@@ -2384,7 +2383,7 @@ def test_local_pow_specialize():
nodes = [node.op for node in f.maker.fgraph.toposort()]
assert len(nodes) == 2
assert nodes[0] == sqrt
assert isinstance(nodes[1].scalar_op, aes.basic.Inv)
assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal)
utt.assert_allclose(f(val_no0), val_no0 ** (-0.5))
......@@ -2410,7 +2409,7 @@ def test_local_pow_specialize_device_more_aggressive_on_cpu():
assert len(nodes) == 2
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) == 6
assert isinstance(nodes[0].scalar_op, aes.Composite)
assert isinstance(nodes[-1].scalar_op, aes.basic.Inv)
assert isinstance(nodes[-1].scalar_op, aes.basic.Reciprocal)
utt.assert_allclose(f(val_no0), val_no0 ** (-15))
f = function([v], v ** (16), mode=mode)
......@@ -2425,7 +2424,7 @@ def test_local_pow_specialize_device_more_aggressive_on_cpu():
assert len(nodes) == 2
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) == 4
assert isinstance(nodes[0].scalar_op, aes.Composite)
assert isinstance(nodes[-1].scalar_op, aes.basic.Inv)
assert isinstance(nodes[-1].scalar_op, aes.basic.Reciprocal)
utt.assert_allclose(f(val_no0), val_no0 ** (-16))
......@@ -2475,7 +2474,7 @@ class TestFuncInverse:
self.assert_func_pair_optimized(cosh, arccosh, dx)
self.assert_func_pair_optimized(arcsinh, sinh, dx)
self.assert_func_pair_optimized(arctanh, tanh, dx)
self.assert_func_pair_optimized(inv, inv, dx)
self.assert_func_pair_optimized(reciprocal, reciprocal, dx)
self.assert_func_pair_optimized(neg, neg, dx)
cx = dx + complex(0, 1) * (dx + 0.01)
self.assert_func_pair_optimized(conj, conj, cx, is_complex=True)
......@@ -2826,12 +2825,13 @@ class TestLocalErfc:
def test_local_log_erfc(self):
val = [-30, -27, -26, -11, -10, -3, -2, -1, 0, 1, 2, 3, 10, 11, 26, 27, 28, 30]
if config.mode in ["DebugMode", "DEBUG_MODE", "FAST_COMPILE"]:
# python mode don't like the inv(0)
# python mode doesn't like the reciprocal(0)
val.remove(0)
val = np.asarray(val, dtype=config.floatX)
x = vector("x")
# their is some nan that will happear in the graph for the log of the negatives values
# there are some `nan`s that will appear in the graph due to the logs
# of negative values
mode = copy.copy(self.mode)
mode.check_isfinite = False
mode_fusion = copy.copy(self.mode_fusion)
......@@ -3761,7 +3761,8 @@ def test_local_add_specialize():
assert transformed[0].type == s.type
def test_local_div_to_inv():
def test_local_div_to_reciprocal():
# XXX TODO: This does *not* test `local_div_to_reciprocal`!
num_len_s = lscalar("num_len")
denom_s = scalar("denom")
......
......@@ -840,7 +840,7 @@ _grad_broadcast_binary_normal = dict(
# complex3=(rand(2,3),randcomplex(2,3)),
)
_good_inv = dict(
_good_reciprocal = dict(
normal=[5 * rand_nonzero((2, 3))],
integers=[randint_nonzero(2, 3)],
int8=[np.array(list(range(-127, 0)) + list(range(1, 127)), dtype="int8")],
......@@ -850,14 +850,15 @@ _good_inv = dict(
empty=[np.asarray([], dtype=config.floatX)],
)
_good_inv_inplace = copymod(
_good_inv, without=["integers", "int8", "uint8", "uint16", "complex"]
_good_reciprocal_inplace = copymod(
_good_reciprocal, without=["integers", "int8", "uint8", "uint16", "complex"]
)
_grad_inv = copymod(
_good_inv, without=["integers", "int8", "uint8", "uint16", "complex", "empty"]
_grad_reciprocal = copymod(
_good_reciprocal,
without=["integers", "int8", "uint8", "uint16", "complex", "empty"],
)
_bad_runtime_inv = dict(
_bad_runtime_reciprocal = dict(
float=[np.zeros((2, 3))],
integers=[np.zeros((2, 3), dtype="int64")],
int8=[np.zeros((2, 3), dtype="int8")],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论