提交 56327779 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify logic with `variadic_add` and `variadic_mul` helpers

上级 cdae9037
......@@ -102,7 +102,7 @@ from pytensor.tensor import basic as ptb
from pytensor.tensor.basic import expand_dims
from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import add, mul, neg, sub
from pytensor.tensor.math import add, mul, neg, sub, variadic_add
from pytensor.tensor.shape import shape_padright, specify_broadcastable
from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
......@@ -1399,11 +1399,7 @@ def _gemm_from_factored_list(fgraph, lst):
item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
]
add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1:
rval = [add(*add_inputs)]
else:
rval = add_inputs
# print "RETURNING GEMM THING", rval
rval = [variadic_add(*add_inputs)]
return rval, old_dot22
......
......@@ -1429,18 +1429,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
else:
shp = cast(shp, "float64")
if axis is None:
axis = list(range(input.ndim))
elif isinstance(axis, int | np.integer):
axis = [axis]
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
axis = [int(axis)]
else:
axis = [int(a) for a in axis]
# This sequential division will possibly be optimized by PyTensor:
for i in axis:
s = true_div(s, shp[i])
reduced_dims = (
shp
if axis is None
else [shp[i] for i in normalize_axis_tuple(axis, input.type.ndim)]
)
s /= variadic_mul(*reduced_dims).astype(shp.dtype)
# This can happen when axis is an empty list/tuple
if s.dtype != shp.dtype and s.dtype in discrete_dtypes:
......@@ -1596,6 +1590,15 @@ def add(a, *other_terms):
# see decorator for function body
def variadic_add(*args):
    """Add that accepts arbitrary number of inputs, including zero or one.

    With no inputs the additive identity ``constant(0)`` is returned; a
    single input is returned unchanged; otherwise all inputs are combined
    with ``add``.
    """
    n_args = len(args)
    if n_args == 0:
        return constant(0)
    if n_args == 1:
        return args[0]
    return add(*args)
@scalar_elemwise
def sub(a, b):
"""elementwise subtraction"""
......@@ -1608,6 +1611,15 @@ def mul(a, *other_terms):
# see decorator for function body
def variadic_mul(*args):
    """Mul that accepts arbitrary number of inputs, including zero or one.

    With no inputs the multiplicative identity ``constant(1)`` is returned;
    a single input is returned unchanged; otherwise all inputs are combined
    with ``mul``.
    """
    n_args = len(args)
    if n_args == 0:
        return constant(1)
    if n_args == 1:
        return args[0]
    return mul(*args)
@scalar_elemwise
def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)"""
......
......@@ -68,7 +68,7 @@ from pytensor.tensor.basic import (
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.math import Sum, add, eq
from pytensor.tensor.math import Sum, add, eq, variadic_add
from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.variable import TensorConstant, TensorVariable
......@@ -939,13 +939,8 @@ def local_sum_make_vector(fgraph, node):
if acc_dtype == "float64" and out_dtype != "float64" and config.floatX != "float64":
return
if len(elements) == 0:
element_sum = zeros(dtype=out_dtype, shape=())
elif len(elements) == 1:
element_sum = cast(elements[0], out_dtype)
else:
element_sum = cast(
add(*[cast(value, acc_dtype) for value in elements]), out_dtype
variadic_add(*[cast(value, acc_dtype) for value in elements]), out_dtype
)
return [element_sum]
......
......@@ -96,7 +96,15 @@ from pytensor.tensor.blas import (
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
from pytensor.tensor.math import (
Dot,
_matrix_matrix_matmul,
add,
mul,
neg,
sub,
variadic_add,
)
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.type import (
DenseTensorType,
......@@ -386,10 +394,7 @@ def _gemm_from_factored_list(fgraph, lst):
item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
]
add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1:
rval = [add(*add_inputs)]
else:
rval = add_inputs
rval = [variadic_add(*add_inputs)]
# print "RETURNING GEMM THING", rval
return rval, old_dot22
......
......@@ -76,6 +76,8 @@ from pytensor.tensor.math import (
sub,
tri_gamma,
true_div,
variadic_add,
variadic_mul,
)
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import max as pt_max
......@@ -1270,17 +1272,13 @@ def local_sum_prod_of_mul_or_div(fgraph, node):
if not outer_terms:
return None
elif len(outer_terms) == 1:
[outer_term] = outer_terms
else:
outer_term = mul(*outer_terms)
outer_term = variadic_mul(*outer_terms)
if not inner_terms:
inner_term = None
elif len(inner_terms) == 1:
[inner_term] = inner_terms
else:
inner_term = mul(*inner_terms)
inner_term = variadic_mul(*inner_terms)
else: # true_div
# We only care about removing the denominator out of the reduction
......@@ -2143,10 +2141,7 @@ def local_add_remove_zeros(fgraph, node):
assert cst.type.broadcastable == (True,) * ndim
return [alloc_like(cst, node_output, fgraph)]
if len(new_inputs) == 1:
ret = [alloc_like(new_inputs[0], node_output, fgraph)]
else:
ret = [alloc_like(add(*new_inputs), node_output, fgraph)]
ret = [alloc_like(variadic_add(*new_inputs), node_output, fgraph)]
# The dtype should not be changed. It can happen if the input
# that was forcing upcasting was equal to 0.
......@@ -2257,10 +2252,7 @@ def local_log1p(fgraph, node):
# scalar_inputs are potentially dimshuffled and fill'd scalars
if scalars and np.allclose(np.sum(scalars), 1):
if nonconsts:
if len(nonconsts) > 1:
ninp = add(*nonconsts)
else:
ninp = nonconsts[0]
ninp = variadic_add(*nonconsts)
if ninp.dtype != log_arg.type.dtype:
ninp = ninp.astype(node.outputs[0].dtype)
return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
......@@ -3084,10 +3076,7 @@ def local_exp_over_1_plus_exp(fgraph, node):
return
# put the new numerator together
new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest
if len(new_num) == 1:
new_num = new_num[0]
else:
new_num = mul(*new_num)
new_num = variadic_mul(*new_num)
if num_neg ^ denom_neg:
new_num = -new_num
......
......@@ -48,6 +48,7 @@ from pytensor.tensor.math import (
maximum,
minimum,
or_,
variadic_add,
)
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.rewriting.basic import (
......@@ -1241,11 +1242,7 @@ def local_IncSubtensor_serialize(fgraph, node):
new_inputs = [i for i in node.inputs if not movable(i)] + [
mi.owner.inputs[0] for mi in movable_inputs
]
if len(new_inputs) == 0:
new_add = new_inputs[0]
else:
new_add = add(*new_inputs)
new_add = variadic_add(*new_inputs)
# Copy over stacktrace from original output, as an error
# (e.g. an index error) in this add operation should
# correspond to an error in the original add operation.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论