提交 c946160a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Use second for broadcast_arrays and remove fill_chain helper

上级 74d78256
...@@ -23,7 +23,7 @@ from pytensor.scalar import int32 as int_t ...@@ -23,7 +23,7 @@ from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast from pytensor.scalar import upcast
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at from pytensor.tensor import basic as at
from pytensor.tensor import get_vector_length from pytensor.tensor.basic import get_vector_length, second
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as pt_abs from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import all as pt_all
...@@ -1780,7 +1780,19 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]: ...@@ -1780,7 +1780,19 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
The arrays to broadcast. The arrays to broadcast.
""" """
return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
def broadcast_with_others(a, others):
    """Broadcast `a` against every variable in `others`.

    Chains `second`, which — presumably (confirm against
    `pytensor.tensor.basic.second`) — returns its second argument
    broadcast to the shape of the first.
    """
    result = a
    for template in others:
        result = second(template, result)
    return result
brodacasted_vars = []
for i, a in enumerate(args):
# We use indexing and not identity in case there are duplicated variables
others = [a for j, a in enumerate(args) if j != i]
brodacasted_vars.append(broadcast_with_others(a, others))
return brodacasted_vars
__all__ = [ __all__ = [
......
...@@ -38,6 +38,7 @@ from pytensor.tensor.basic import ( ...@@ -38,6 +38,7 @@ from pytensor.tensor.basic import (
) )
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.math import ( from pytensor.tensor.math import (
All, All,
Any, Any,
...@@ -148,12 +149,6 @@ def get_constant(v): ...@@ -148,12 +149,6 @@ def get_constant(v):
return v return v
def fill_chain(new_out, orig_inputs):
for i in orig_inputs:
new_out = fill(i, new_out)
return [new_out]
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@node_rewriter([Dot]) @node_rewriter([Dot])
...@@ -1136,7 +1131,7 @@ class AlgebraicCanonizer(NodeRewriter): ...@@ -1136,7 +1131,7 @@ class AlgebraicCanonizer(NodeRewriter):
new = cast(new, out.type.dtype) new = cast(new, out.type.dtype)
if new.type.broadcastable != out.type.broadcastable: if new.type.broadcastable != out.type.broadcastable:
new = fill_chain(new, node.inputs)[0] new = broadcast_arrays(new, *node.inputs)[0]
if (new.type.dtype == out.type.dtype) and ( if (new.type.dtype == out.type.dtype) and (
new.type.broadcastable == out.type.broadcastable new.type.broadcastable == out.type.broadcastable
...@@ -1961,7 +1956,9 @@ def local_mul_zero(fgraph, node): ...@@ -1961,7 +1956,9 @@ def local_mul_zero(fgraph, node):
# print 'MUL by value', value, node.inputs # print 'MUL by value', value, node.inputs
if value == 0: if value == 0:
# print '... returning zeros' # print '... returning zeros'
return fill_chain(_asarray(0, dtype=otype.dtype), node.inputs) return [
broadcast_arrays(_asarray(0, dtype=otype.dtype), *node.inputs)[0]
]
# TODO: Add this to the canonicalization to reduce redundancy. # TODO: Add this to the canonicalization to reduce redundancy.
...@@ -2260,12 +2257,12 @@ def local_add_specialize(fgraph, node): ...@@ -2260,12 +2257,12 @@ def local_add_specialize(fgraph, node):
# Reuse call to constant for cache() # Reuse call to constant for cache()
cst = constant(np.zeros((1,) * ndim, dtype=dtype)) cst = constant(np.zeros((1,) * ndim, dtype=dtype))
assert cst.type.broadcastable == (True,) * ndim assert cst.type.broadcastable == (True,) * ndim
return fill_chain(cst, node.inputs) return [broadcast_arrays(cst, *node.inputs)[0]]
if len(new_inputs) == 1: if len(new_inputs) == 1:
ret = fill_chain(new_inputs[0], node.inputs) ret = [broadcast_arrays(new_inputs[0], *node.inputs)[0]]
else: else:
ret = fill_chain(add(*new_inputs), node.inputs) ret = [broadcast_arrays(add(*new_inputs), *node.inputs)[0]]
# The dtype should not be changed. It can happen if the input # The dtype should not be changed. It can happen if the input
# that was forcing upcasting was equal to 0. # that was forcing upcasting was equal to 0.
...@@ -2383,7 +2380,7 @@ def local_log1p(fgraph, node): ...@@ -2383,7 +2380,7 @@ def local_log1p(fgraph, node):
ninp = nonconsts[0] ninp = nonconsts[0]
if ninp.dtype != log_arg.type.dtype: if ninp.dtype != log_arg.type.dtype:
ninp = ninp.astype(node.outputs[0].dtype) ninp = ninp.astype(node.outputs[0].dtype)
return fill_chain(log1p(ninp), scalar_inputs) return [broadcast_arrays(log1p(ninp), *scalar_inputs)[0]]
elif log_arg.owner and log_arg.owner.op == sub: elif log_arg.owner and log_arg.owner.op == sub:
one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True) one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
...@@ -3578,10 +3575,12 @@ def local_reciprocal_1_plus_exp(fgraph, node): ...@@ -3578,10 +3575,12 @@ def local_reciprocal_1_plus_exp(fgraph, node):
if len(nonconsts) == 1: if len(nonconsts) == 1:
if nonconsts[0].owner and nonconsts[0].owner.op == exp: if nonconsts[0].owner and nonconsts[0].owner.op == exp:
if scalars_ and np.allclose(np.sum(scalars_), 1): if scalars_ and np.allclose(np.sum(scalars_), 1):
out = fill_chain( out = [
sigmoid(neg(nonconsts[0].owner.inputs[0])), broadcast_arrays(
scalar_inputs, sigmoid(neg(nonconsts[0].owner.inputs[0])),
) *scalar_inputs,
)[0]
]
# keep combined stack traces of # keep combined stack traces of
# exp(x): nonconsts[0], # exp(x): nonconsts[0],
# 1 + exp(x): reciprocal_arg, # 1 + exp(x): reciprocal_arg,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论