提交 67519be2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rename broadcast_like to alloc_like

上级 316ce0ba
...@@ -8,6 +8,7 @@ import numpy as np ...@@ -8,6 +8,7 @@ import numpy as np
import pytensor.scalar.basic as aes import pytensor.scalar.basic as aes
from pytensor import compile from pytensor import compile
from pytensor.compile.ops import ViewOp from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, Variable from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
NodeRewriter, NodeRewriter,
...@@ -87,13 +88,13 @@ def merge_broadcastables(broadcastables): ...@@ -87,13 +88,13 @@ def merge_broadcastables(broadcastables):
return [all(bcast) for bcast in zip(*broadcastables)] return [all(bcast) for bcast in zip(*broadcastables)]
def broadcast_like(value, template, fgraph, dtype=None): def alloc_like(
""" value: TensorVariable,
Return a Variable with the same shape and dtype as the template, template: TensorVariable,
filled by broadcasting value through it. `value` will be cast as fgraph: FunctionGraph,
necessary. dtype=None,
) -> TensorVariable:
""" """Fill value to the same shape and dtype as the template via alloc."""
value = as_tensor_variable(value) value = as_tensor_variable(value)
if value.type.is_super(template.type): if value.type.is_super(template.type):
return value return value
...@@ -438,7 +439,7 @@ def local_fill_to_alloc(fgraph, node): ...@@ -438,7 +439,7 @@ def local_fill_to_alloc(fgraph, node):
# In this case, we assume that some broadcasting is needed (otherwise # In this case, we assume that some broadcasting is needed (otherwise
# the condition above would've been true), so we replace the `fill` # the condition above would've been true), so we replace the `fill`
# with an `Alloc`. # with an `Alloc`.
o = broadcast_like(values_ref, shape_ref, fgraph, dtype=values_ref.dtype) o = alloc_like(values_ref, shape_ref, fgraph, dtype=values_ref.dtype)
copy_stack_trace(node.outputs[0], o) copy_stack_trace(node.outputs[0], o)
return [o] return [o]
......
...@@ -34,7 +34,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise ...@@ -34,7 +34,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import exp from pytensor.tensor.math import exp
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
broadcast_like, alloc_like,
register_canonicalize, register_canonicalize,
register_specialize, register_specialize,
) )
...@@ -1242,7 +1242,7 @@ def local_inline_composite_constants(fgraph, node): ...@@ -1242,7 +1242,7 @@ def local_inline_composite_constants(fgraph, node):
# Some of the inlined constants were broadcasting the output shape # Some of the inlined constants were broadcasting the output shape
if node.outputs[0].type.broadcastable != new_outputs[0].type.broadcastable: if node.outputs[0].type.broadcastable != new_outputs[0].type.broadcastable:
new_outputs = [ new_outputs = [
broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph) alloc_like(new_out, template=node.outputs[0], fgraph=fgraph)
for new_out in new_outputs for new_out in new_outputs
] ]
......
...@@ -84,7 +84,7 @@ from pytensor.tensor.math import ( ...@@ -84,7 +84,7 @@ from pytensor.tensor.math import (
from pytensor.tensor.math import sum as at_sum from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import true_div from pytensor.tensor.math import true_div
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
broadcast_like, alloc_like,
broadcasted_by, broadcasted_by,
local_fill_sink, local_fill_sink,
register_canonicalize, register_canonicalize,
...@@ -1973,7 +1973,7 @@ def local_div_to_reciprocal(fgraph, node): ...@@ -1973,7 +1973,7 @@ def local_div_to_reciprocal(fgraph, node):
new_out = cast(new_out, dtype=out.dtype) new_out = cast(new_out, dtype=out.dtype)
# The ones could have forced a specific length # The ones could have forced a specific length
if not out.type.is_super(new_out.type): if not out.type.is_super(new_out.type):
new_out = broadcast_like(new_out, out, fgraph) new_out = alloc_like(new_out, out, fgraph)
return [new_out] return [new_out]
else: else:
return False return False
...@@ -1994,9 +1994,9 @@ def local_pow_canonicalize(fgraph, node): ...@@ -1994,9 +1994,9 @@ def local_pow_canonicalize(fgraph, node):
if node.op == at_pow: if node.op == at_pow:
cst = get_constant(node.inputs[1]) cst = get_constant(node.inputs[1])
if cst == 0: if cst == 0:
return [broadcast_like(1, node.outputs[0], fgraph)] return [alloc_like(1, node.outputs[0], fgraph)]
if cst == 1: if cst == 1:
return [broadcast_like(node.inputs[0], node.outputs[0], fgraph)] return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
else: else:
return False return False
...@@ -2033,7 +2033,7 @@ def local_zero_div(fgraph, node): ...@@ -2033,7 +2033,7 @@ def local_zero_div(fgraph, node):
node.op.scalar_op, (aes.IntDiv, aes.TrueDiv) node.op.scalar_op, (aes.IntDiv, aes.TrueDiv)
): ):
if get_constant(node.inputs[0]) == 0: if get_constant(node.inputs[0]) == 0:
ret = broadcast_like(0, node.outputs[0], fgraph) ret = alloc_like(0, node.outputs[0], fgraph)
ret.tag.values_eq_approx = values_eq_approx_remove_nan ret.tag.values_eq_approx = values_eq_approx_remove_nan
return [ret] return [ret]
...@@ -2184,7 +2184,7 @@ def local_mul_specialize(fgraph, node): ...@@ -2184,7 +2184,7 @@ def local_mul_specialize(fgraph, node):
has_neg ^= True # toggles has_neg ^= True # toggles
elif y == 0.0: elif y == 0.0:
# if we find any zero, we just return right away # if we find any zero, we just return right away
return [broadcast_like(0, node.outputs[0], fgraph)] return [alloc_like(0, node.outputs[0], fgraph)]
else: else:
new_inputs.append(inp) new_inputs.append(inp)
...@@ -2209,14 +2209,14 @@ def local_mul_specialize(fgraph, node): ...@@ -2209,14 +2209,14 @@ def local_mul_specialize(fgraph, node):
new_inputs = [m1] + new_inputs new_inputs = [m1] + new_inputs
rval = mul(*new_inputs) rval = mul(*new_inputs)
return [broadcast_like(rval, node.outputs[0], fgraph)] return [alloc_like(rval, node.outputs[0], fgraph)]
else: else:
# there are no variable inputs to mul # there are no variable inputs to mul
# N.B. this could have been constant-folded... # N.B. this could have been constant-folded...
if has_neg: if has_neg:
return [broadcast_like(-1, node.outputs[0], fgraph)] return [alloc_like(-1, node.outputs[0], fgraph)]
else: else:
return [broadcast_like(1, node.outputs[0], fgraph)] return [alloc_like(1, node.outputs[0], fgraph)]
@register_specialize @register_specialize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论