Commit 56327779 authored by Ricardo Vieira, committed by Ricardo Vieira

Simplify logic with `variadic_add` and `variadic_mul` helpers

Parent cdae9037
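
For context, a minimal sketch of the call-site pattern this commit collapses; the `terms` list and scalar variables are illustrative, not taken from the diff:

import pytensor.tensor as pt
from pytensor.tensor.math import add, variadic_add

terms = [pt.scalar("a"), pt.scalar("b")]

# Before: every call site branched on the number of terms
out = add(*terms) if len(terms) > 1 else terms[0]

# After: one helper call covers the empty, single, and multi-term cases
out = variadic_add(*terms)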
@@ -102,7 +102,7 @@ from pytensor.tensor import basic as ptb
 from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import add, mul, neg, sub
+from pytensor.tensor.math import add, mul, neg, sub, variadic_add
 from pytensor.tensor.shape import shape_padright, specify_broadcastable
 from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
@@ -1399,11 +1399,7 @@ def _gemm_from_factored_list(fgraph, lst):
                     item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
                 ]
                 add_inputs.extend(gemm_of_sM_list)
-                if len(add_inputs) > 1:
-                    rval = [add(*add_inputs)]
-                else:
-                    rval = add_inputs
-                # print "RETURNING GEMM THING", rval
+                rval = [variadic_add(*add_inputs)]
                 return rval, old_dot22
...
@@ -1429,18 +1429,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
     else:
         shp = cast(shp, "float64")

-    if axis is None:
-        axis = list(range(input.ndim))
-    elif isinstance(axis, int | np.integer):
-        axis = [axis]
-    elif isinstance(axis, np.ndarray) and axis.ndim == 0:
-        axis = [int(axis)]
-    else:
-        axis = [int(a) for a in axis]
-
-    # This sequential division will possibly be optimized by PyTensor:
-    for i in axis:
-        s = true_div(s, shp[i])
+    reduced_dims = (
+        shp
+        if axis is None
+        else [shp[i] for i in normalize_axis_tuple(axis, input.type.ndim)]
+    )
+    s /= variadic_mul(*reduced_dims).astype(shp.dtype)

     # This can happen when axis is an empty list/tuple
     if s.dtype != shp.dtype and s.dtype in discrete_dtypes:
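
The rewritten `mean` now divides once by the product of the reduced dimensions (via `variadic_mul`) instead of issuing one `true_div` per axis. A small check of that equivalence, with plain numpy standing in for the symbolic graph (not part of the diff):

import numpy as np

x = np.arange(24.0).reshape(2, 3, 4)
axis = (0, 2)

s = x.sum(axis=axis)
sequential = s.copy()
for i in axis:
    sequential = sequential / x.shape[i]  # old: one division per reduced axis

# new: a single division by the product of the reduced dimensions
single = s / np.prod([x.shape[i] for i in axis])

assert np.allclose(sequential, single)
assert np.allclose(single, x.mean(axis=axis))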
@@ -1596,6 +1590,15 @@ def add(a, *other_terms):
     # see decorator for function body


+def variadic_add(*args):
+    """Add that accepts arbitrary number of inputs, including zero or one."""
+    if not args:
+        return constant(0)
+    if len(args) == 1:
+        return args[0]
+    return add(*args)
+
+
 @scalar_elemwise
 def sub(a, b):
     """elementwise subtraction"""
@@ -1608,6 +1611,15 @@ def mul(a, *other_terms):
     # see decorator for function body


+def variadic_mul(*args):
+    """Mul that accepts arbitrary number of inputs, including zero or one."""
+    if not args:
+        return constant(1)
+    if len(args) == 1:
+        return args[0]
+    return mul(*args)
+
+
 @scalar_elemwise
 def true_div(a, b):
     """elementwise [true] division (inverse of multiplication)"""
...
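
A usage sketch of the two new helpers, following directly from the definitions above (the scalar variables are illustrative):

import pytensor.tensor as pt
from pytensor.tensor.math import variadic_add, variadic_mul

a, b = pt.scalar("a"), pt.scalar("b")

variadic_add()      # no inputs  -> constant(0), the additive identity
variadic_add(a)     # one input  -> returned as-is, no Add node is built
variadic_add(a, b)  # otherwise  -> same as add(a, b)

variadic_mul()      # no inputs  -> constant(1), the multiplicative identity
variadic_mul(a, b)  # otherwise  -> same as mul(a, b)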
@@ -68,7 +68,7 @@ from pytensor.tensor.basic import (
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
-from pytensor.tensor.math import Sum, add, eq
+from pytensor.tensor.math import Sum, add, eq, variadic_add
 from pytensor.tensor.shape import Shape_i, shape_padleft
 from pytensor.tensor.type import DenseTensorType, TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -939,14 +939,9 @@ def local_sum_make_vector(fgraph, node):
     if acc_dtype == "float64" and out_dtype != "float64" and config.floatX != "float64":
         return

-    if len(elements) == 0:
-        element_sum = zeros(dtype=out_dtype, shape=())
-    elif len(elements) == 1:
-        element_sum = cast(elements[0], out_dtype)
-    else:
-        element_sum = cast(
-            add(*[cast(value, acc_dtype) for value in elements]), out_dtype
-        )
+    element_sum = cast(
+        variadic_add(*[cast(value, acc_dtype) for value in elements]), out_dtype
+    )

     return [element_sum]
...
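
For reference, `local_sum_make_vector` (simplified above) rewrites the sum of a vector assembled from scalars into a direct addition. A hedged illustration; the exact optimized graph may vary by PyTensor version and floatX setting:

import pytensor
import pytensor.tensor as pt

a, b = pt.scalar("a"), pt.scalar("b")
# Summing a vector literally built from scalars...
f = pytensor.function([a, b], pt.stack([a, b]).sum())
# ...should compile to a plain addition rather than MakeVector + Sum
pytensor.dprint(f)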
@@ -96,7 +96,15 @@ from pytensor.tensor.blas import (
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
+from pytensor.tensor.math import (
+    Dot,
+    _matrix_matrix_matmul,
+    add,
+    mul,
+    neg,
+    sub,
+    variadic_add,
+)
 from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
 from pytensor.tensor.type import (
     DenseTensorType,
@@ -386,10 +394,7 @@ def _gemm_from_factored_list(fgraph, lst):
                     item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
                 ]
                 add_inputs.extend(gemm_of_sM_list)
-                if len(add_inputs) > 1:
-                    rval = [add(*add_inputs)]
-                else:
-                    rval = add_inputs
+                rval = [variadic_add(*add_inputs)]
                 # print "RETURNING GEMM THING", rval
                 return rval, old_dot22
...
@@ -76,6 +76,8 @@ from pytensor.tensor.math import (
     sub,
     tri_gamma,
     true_div,
+    variadic_add,
+    variadic_mul,
 )
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import max as pt_max
@@ -1270,17 +1272,13 @@ def local_sum_prod_of_mul_or_div(fgraph, node):
             if not outer_terms:
                 return None
-            elif len(outer_terms) == 1:
-                [outer_term] = outer_terms
             else:
-                outer_term = mul(*outer_terms)
+                outer_term = variadic_mul(*outer_terms)

             if not inner_terms:
                 inner_term = None
-            elif len(inner_terms) == 1:
-                [inner_term] = inner_terms
             else:
-                inner_term = mul(*inner_terms)
+                inner_term = variadic_mul(*inner_terms)

         else:  # true_div
             # We only care about removing the denominator out of the reduction
@@ -2143,10 +2141,7 @@ def local_add_remove_zeros(fgraph, node):
         assert cst.type.broadcastable == (True,) * ndim
         return [alloc_like(cst, node_output, fgraph)]

-    if len(new_inputs) == 1:
-        ret = [alloc_like(new_inputs[0], node_output, fgraph)]
-    else:
-        ret = [alloc_like(add(*new_inputs), node_output, fgraph)]
+    ret = [alloc_like(variadic_add(*new_inputs), node_output, fgraph)]

     # The dtype should not be changed. It can happen if the input
     # that was forcing upcasting was equal to 0.
@@ -2257,10 +2252,7 @@ def local_log1p(fgraph, node):
     # scalar_inputs are potentially dimshuffled and fill'd scalars
     if scalars and np.allclose(np.sum(scalars), 1):
         if nonconsts:
-            if len(nonconsts) > 1:
-                ninp = add(*nonconsts)
-            else:
-                ninp = nonconsts[0]
+            ninp = variadic_add(*nonconsts)
             if ninp.dtype != log_arg.type.dtype:
                 ninp = ninp.astype(node.outputs[0].dtype)
             return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
@@ -3084,10 +3076,7 @@ def local_exp_over_1_plus_exp(fgraph, node):
             return
         # put the new numerator together
         new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest
-        if len(new_num) == 1:
-            new_num = new_num[0]
-        else:
-            new_num = mul(*new_num)
+        new_num = variadic_mul(*new_num)
         if num_neg ^ denom_neg:
             new_num = -new_num
...
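
The last hunk above touches `local_exp_over_1_plus_exp`, the rewrite that recognizes the logistic function. A hedged illustration of its end-to-end effect (not part of the diff):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
f = pytensor.function([x], pt.exp(x) / (1 + pt.exp(x)))
# The optimized graph is expected to contain a single sigmoid node
pytensor.dprint(f)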
@@ -48,6 +48,7 @@ from pytensor.tensor.math import (
     maximum,
     minimum,
     or_,
+    variadic_add,
 )
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.rewriting.basic import (
@@ -1241,15 +1242,11 @@ def local_IncSubtensor_serialize(fgraph, node):
             new_inputs = [i for i in node.inputs if not movable(i)] + [
                 mi.owner.inputs[0] for mi in movable_inputs
             ]
-            if len(new_inputs) == 0:
-                new_add = new_inputs[0]
-            else:
-                new_add = add(*new_inputs)
-
-            # Copy over stacktrace from original output, as an error
-            # (e.g. an index error) in this add operation should
-            # correspond to an error in the original add operation.
-            copy_stack_trace(node.outputs[0], new_add)
+            new_add = variadic_add(*new_inputs)
+            # Copy over stacktrace from original output, as an error
+            # (e.g. an index error) in this add operation should
+            # correspond to an error in the original add operation.
+            copy_stack_trace(node.outputs[0], new_add)

             # stack up the new incsubtensors
             tip = new_add
...
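
One detail worth noting: as extracted, the removed branch indexed `new_inputs[0]` behind a `len(new_inputs) == 0` check, a path that cannot succeed on an empty list; `variadic_add` gives the empty case a well-defined result instead:

from pytensor.tensor.math import variadic_add

# With no inputs the helper falls back to the additive identity
print(variadic_add())  # a constant 0 tensor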