提交 67519be2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rename broadcast_like to alloc_like

上级 316ce0ba
......@@ -8,6 +8,7 @@ import numpy as np
import pytensor.scalar.basic as aes
from pytensor import compile
from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import (
NodeRewriter,
......@@ -87,13 +88,13 @@ def merge_broadcastables(broadcastables):
return [all(bcast) for bcast in zip(*broadcastables)]
def broadcast_like(value, template, fgraph, dtype=None):
"""
Return a Variable with the same shape and dtype as the template,
filled by broadcasting value through it. `value` will be cast as
necessary.
"""
def alloc_like(
value: TensorVariable,
template: TensorVariable,
fgraph: FunctionGraph,
dtype=None,
) -> TensorVariable:
"""Fill value to the same shape and dtype as the template via alloc."""
value = as_tensor_variable(value)
if value.type.is_super(template.type):
return value
......@@ -438,7 +439,7 @@ def local_fill_to_alloc(fgraph, node):
# In this case, we assume that some broadcasting is needed (otherwise
# the condition above would've been true), so we replace the `fill`
# with an `Alloc`.
o = broadcast_like(values_ref, shape_ref, fgraph, dtype=values_ref.dtype)
o = alloc_like(values_ref, shape_ref, fgraph, dtype=values_ref.dtype)
copy_stack_trace(node.outputs[0], o)
return [o]
......
......@@ -34,7 +34,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import exp
from pytensor.tensor.rewriting.basic import (
broadcast_like,
alloc_like,
register_canonicalize,
register_specialize,
)
......@@ -1242,7 +1242,7 @@ def local_inline_composite_constants(fgraph, node):
# Some of the inlined constants were broadcasting the output shape
if node.outputs[0].type.broadcastable != new_outputs[0].type.broadcastable:
new_outputs = [
broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph)
alloc_like(new_out, template=node.outputs[0], fgraph=fgraph)
for new_out in new_outputs
]
......
......@@ -84,7 +84,7 @@ from pytensor.tensor.math import (
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import true_div
from pytensor.tensor.rewriting.basic import (
broadcast_like,
alloc_like,
broadcasted_by,
local_fill_sink,
register_canonicalize,
......@@ -1973,7 +1973,7 @@ def local_div_to_reciprocal(fgraph, node):
new_out = cast(new_out, dtype=out.dtype)
# The ones could have forced a specific length
if not out.type.is_super(new_out.type):
new_out = broadcast_like(new_out, out, fgraph)
new_out = alloc_like(new_out, out, fgraph)
return [new_out]
else:
return False
......@@ -1994,9 +1994,9 @@ def local_pow_canonicalize(fgraph, node):
if node.op == at_pow:
cst = get_constant(node.inputs[1])
if cst == 0:
return [broadcast_like(1, node.outputs[0], fgraph)]
return [alloc_like(1, node.outputs[0], fgraph)]
if cst == 1:
return [broadcast_like(node.inputs[0], node.outputs[0], fgraph)]
return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
else:
return False
......@@ -2033,7 +2033,7 @@ def local_zero_div(fgraph, node):
node.op.scalar_op, (aes.IntDiv, aes.TrueDiv)
):
if get_constant(node.inputs[0]) == 0:
ret = broadcast_like(0, node.outputs[0], fgraph)
ret = alloc_like(0, node.outputs[0], fgraph)
ret.tag.values_eq_approx = values_eq_approx_remove_nan
return [ret]
......@@ -2184,7 +2184,7 @@ def local_mul_specialize(fgraph, node):
has_neg ^= True # toggles
elif y == 0.0:
# if we find any zero, we just return right away
return [broadcast_like(0, node.outputs[0], fgraph)]
return [alloc_like(0, node.outputs[0], fgraph)]
else:
new_inputs.append(inp)
......@@ -2209,14 +2209,14 @@ def local_mul_specialize(fgraph, node):
new_inputs = [m1] + new_inputs
rval = mul(*new_inputs)
return [broadcast_like(rval, node.outputs[0], fgraph)]
return [alloc_like(rval, node.outputs[0], fgraph)]
else:
# there are no variable inputs to mul
# N.B. this could have been constant-folded...
if has_neg:
return [broadcast_like(-1, node.outputs[0], fgraph)]
return [alloc_like(-1, node.outputs[0], fgraph)]
else:
return [broadcast_like(1, node.outputs[0], fgraph)]
return [alloc_like(1, node.outputs[0], fgraph)]
@register_specialize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论