提交 46f8227d 作者: Ricardo Vieira 提交者: Ricardo Vieira

Use more direct imports in rewriting/elemwise.py

上级 d26374cd
......@@ -8,16 +8,15 @@ from heapq import heapify, heappop, heappush
from operator import or_
from warnings import warn
import pytensor.scalar.basic as ps
from pytensor import clone_replace, compile
from pytensor.compile.function.types import Supervisor
from pytensor.compile.mode import get_target_language
from pytensor.compile.mode import get_target_language, optdb
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
from pytensor.graph.features import ReplaceValidate
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import (
GraphRewriter,
copy_stack_trace,
......@@ -30,11 +29,21 @@ from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.graph.traversal import toposort
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import (
MakeVector,
constant,
from pytensor.scalar import (
Add,
Composite,
Mul,
ScalarOp,
get_scalar_type,
transfer_type,
upcast_out,
upgrade_to_float,
)
from pytensor.scalar import cast as scalar_cast
from pytensor.scalar import constant as scalar_constant
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import MakeVector
from pytensor.tensor.basic import constant as tensor_constant
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import add, exp, mul
from pytensor.tensor.rewriting.basic import (
......@@ -280,7 +289,7 @@ class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
if hasattr(scalar_op, "make_new_inplace"):
new_scalar_op = scalar_op.make_new_inplace(
ps.transfer_type(
transfer_type(
*[
inplace_pattern.get(i, o.dtype)
for i, o in enumerate(node.outputs)
......@@ -289,14 +298,14 @@ class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
)
else:
new_scalar_op = type(scalar_op)(
ps.transfer_type(
transfer_type(
*[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
)
)
return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
compile.optdb.register(
optdb.register(
"inplace_elemwise",
InplaceElemwiseOptimizer(),
"inplace_elemwise_opt", # for historic reason
......@@ -428,10 +437,8 @@ def local_useless_dimshuffle_makevector(fgraph, node):
@register_canonicalize
@node_rewriter(
[
elemwise_of(
OpPattern(ps.ScalarOp, output_types_preference=ps.upgrade_to_float)
),
elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out)),
elemwise_of(OpPattern(ScalarOp, output_types_preference=upgrade_to_float)),
elemwise_of(OpPattern(ScalarOp, output_types_preference=upcast_out)),
]
)
def local_upcast_elemwise_constant_inputs(fgraph, node):
......@@ -452,7 +459,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
changed = False
for i, inp in enumerate(node.inputs):
if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant):
new_inputs[i] = constant(inp.data.astype(output_dtype))
new_inputs[i] = tensor_constant(inp.data.astype(output_dtype))
changed = True
if not changed:
......@@ -531,7 +538,7 @@ class FusionOptimizer(GraphRewriter):
@staticmethod
def elemwise_to_scalar(inputs, outputs):
replacement = {
inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
inp: get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
}
for node in toposort(outputs, blockers=inputs):
scalar_inputs = [replacement[inp] for inp in node.inputs]
......@@ -853,7 +860,7 @@ class FusionOptimizer(GraphRewriter):
scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
composite_outputs = Elemwise(
# No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables
ps.Composite(scalar_inputs, scalar_outputs, clone_graph=False)
Composite(scalar_inputs, scalar_outputs, clone_graph=False)
)(*inputs, return_list=True)
assert len(outputs) == len(composite_outputs)
for old_out, composite_out in zip(outputs, composite_outputs):
......@@ -913,7 +920,7 @@ class FusionOptimizer(GraphRewriter):
@register_canonicalize
@register_specialize
@node_rewriter([elemwise_of(ps.Composite)])
@node_rewriter([elemwise_of(Composite)])
def local_useless_composite_outputs(fgraph, node):
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
comp = node.op.scalar_op
......@@ -934,7 +941,7 @@ def local_useless_composite_outputs(fgraph, node):
node.outputs
):
used_inputs = [node.inputs[i] for i in used_inputs_idxs]
c = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
c = Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
e = Elemwise(scalar_op=c)(*used_inputs, return_list=True)
return dict(zip([node.outputs[i] for i in used_outputs_idxs], e, strict=True))
......@@ -948,7 +955,7 @@ def local_careduce_fusion(fgraph, node):
# FIXME: This check is needed because of the faulty logic in the FIXME below!
# Right now, rewrite only works for `Sum`/`Prod`
if not isinstance(car_scalar_op, ps.Add | ps.Mul):
if not isinstance(car_scalar_op, Add | Mul):
return None
elm_node = car_input.owner
......@@ -992,19 +999,19 @@ def local_careduce_fusion(fgraph, node):
car_acc_dtype = node.op.acc_dtype
scalar_elm_inputs = [
ps.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
]
elm_output = elm_scalar_op(*scalar_elm_inputs)
# This input represents the previous value in the `CAReduce` binary reduction
carried_car_input = ps.get_scalar_type(car_acc_dtype).make_variable()
carried_car_input = get_scalar_type(car_acc_dtype).make_variable()
scalar_fused_output = car_scalar_op(carried_car_input, elm_output)
if scalar_fused_output.type.dtype != car_acc_dtype:
scalar_fused_output = ps.cast(scalar_fused_output, car_acc_dtype)
scalar_fused_output = scalar_cast(scalar_fused_output, car_acc_dtype)
fused_scalar_op = ps.Composite(
fused_scalar_op = Composite(
inputs=[carried_car_input, *scalar_elm_inputs], outputs=[scalar_fused_output]
)
......@@ -1025,7 +1032,7 @@ def local_careduce_fusion(fgraph, node):
return [new_car_op(*elm_inputs)]
@node_rewriter([elemwise_of(ps.Composite)])
@node_rewriter([elemwise_of(Composite)])
def local_inline_composite_constants(fgraph, node):
"""Inline scalar constants in Composite graphs."""
composite_op = node.op.scalar_op
......@@ -1041,7 +1048,7 @@ def local_inline_composite_constants(fgraph, node):
and "complex" not in outer_inp.type.dtype
):
if outer_inp.unique_value is not None:
inner_replacements[inner_inp] = ps.constant(
inner_replacements[inner_inp] = scalar_constant(
outer_inp.unique_value, dtype=inner_inp.dtype
)
continue
......@@ -1054,7 +1061,7 @@ def local_inline_composite_constants(fgraph, node):
new_inner_outs = clone_replace(
composite_op.fgraph.outputs, replace=inner_replacements
)
new_composite_op = ps.Composite(new_inner_inputs, new_inner_outs)
new_composite_op = Composite(new_inner_inputs, new_inner_outs)
new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs
# Some of the inlined constants were broadcasting the output shape
......@@ -1095,7 +1102,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):
if other_inps:
python_op = operator.mul if node.op == mul else operator.add
folded_inputs = [reference_inp, *other_inps]
new_inp = constant(
new_inp = tensor_constant(
reduce(python_op, (const.data for const in folded_inputs))
)
new_constants = [
......@@ -1119,7 +1126,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):
add_mul_fusion_seqopt = SequenceDB()
compile.optdb.register(
optdb.register(
"add_mul_fusion",
add_mul_fusion_seqopt,
"fast_run",
......@@ -1140,7 +1147,7 @@ add_mul_fusion_seqopt.register(
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
fuse_seqopt = SequenceDB()
compile.optdb.register(
optdb.register(
"elemwise_fusion",
fuse_seqopt,
"fast_run",
......@@ -1271,7 +1278,7 @@ def split_2f1grad_loop(fgraph, node):
return replacements
compile.optdb["py_only"].register(
optdb["py_only"].register(
"split_2f1grad_loop",
split_2f1grad_loop,
"fast_compile",
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论