Commit 46f8227d authored by Ricardo Vieira, committed by Ricardo Vieira

Use more direct imports in rewriting/elemwise.py

Parent d26374cd
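
The change replaces the `import pytensor.scalar.basic as ps` module alias with direct imports, aliasing the names that would otherwise collide. A minimal sketch of the two styles, using names that appear in this diff:

# Before: one module alias, every symbol accessed through it.
import pytensor.scalar.basic as ps

op_base = ps.ScalarOp
make_var = ps.get_scalar_type("float64").make_variable

# After: direct imports; clashing names (e.g. `constant`) get explicit aliases.
from pytensor.scalar import ScalarOp, get_scalar_type
from pytensor.scalar import constant as scalar_constant
from pytensor.tensor.basic import constant as tensor_constant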
@@ -8,16 +8,15 @@ from heapq import heapify, heappop, heappush
 from operator import or_
 from warnings import warn
 
-import pytensor.scalar.basic as ps
-from pytensor import clone_replace, compile
 from pytensor.compile.function.types import Supervisor
-from pytensor.compile.mode import get_target_language
+from pytensor.compile.mode import get_target_language, optdb
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates
 from pytensor.graph.features import ReplaceValidate
 from pytensor.graph.fg import FunctionGraph, Output
 from pytensor.graph.op import Op
+from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import (
     GraphRewriter,
     copy_stack_trace,
@@ -30,11 +29,21 @@ from pytensor.graph.rewriting.db import SequenceDB
 from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.graph.traversal import toposort
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined
-from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
-from pytensor.tensor.basic import (
-    MakeVector,
-    constant,
+from pytensor.scalar import (
+    Add,
+    Composite,
+    Mul,
+    ScalarOp,
+    get_scalar_type,
+    transfer_type,
+    upcast_out,
+    upgrade_to_float,
 )
+from pytensor.scalar import cast as scalar_cast
+from pytensor.scalar import constant as scalar_constant
+from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
+from pytensor.tensor.basic import MakeVector
+from pytensor.tensor.basic import constant as tensor_constant
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import add, exp, mul
 from pytensor.tensor.rewriting.basic import (
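
Both `pytensor.scalar` and `pytensor.tensor.basic` export a `constant` helper, which is why the new imports above alias them rather than letting one shadow the other. A small sketch of the distinction (a hedged illustration, not part of the diff):

from pytensor.scalar import constant as scalar_constant
from pytensor.tensor.basic import constant as tensor_constant

s = scalar_constant(1.0, dtype="float64")  # ScalarConstant, lives in scalar (Composite) graphs
t = tensor_constant([1.0, 2.0])            # TensorConstant, lives in tensor graphs
print(type(s).__name__, type(t).__name__)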
@@ -280,7 +289,7 @@ class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
             inplace_pattern = {i: o for i, [o] in inplace_pattern.items()}
             if hasattr(scalar_op, "make_new_inplace"):
                 new_scalar_op = scalar_op.make_new_inplace(
-                    ps.transfer_type(
+                    transfer_type(
                         *[
                             inplace_pattern.get(i, o.dtype)
                             for i, o in enumerate(node.outputs)
@@ -289,14 +298,14 @@ class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
                 )
             else:
                 new_scalar_op = type(scalar_op)(
-                    ps.transfer_type(
+                    transfer_type(
                         *[inplace_pattern.get(i, None) for i in range(len(node.outputs))]
                     )
                 )
             return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs)
 
 
-compile.optdb.register(
+optdb.register(
     "inplace_elemwise",
     InplaceElemwiseOptimizer(),
     "inplace_elemwise_opt",  # for historic reason
@@ -428,10 +437,8 @@ def local_useless_dimshuffle_makevector(fgraph, node):
 @register_canonicalize
 @node_rewriter(
     [
-        elemwise_of(
-            OpPattern(ps.ScalarOp, output_types_preference=ps.upgrade_to_float)
-        ),
-        elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out)),
+        elemwise_of(OpPattern(ScalarOp, output_types_preference=upgrade_to_float)),
+        elemwise_of(OpPattern(ScalarOp, output_types_preference=upcast_out)),
     ]
 )
 def local_upcast_elemwise_constant_inputs(fgraph, node):
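
`node_rewriter` registers the decorated function against the listed op patterns; with the direct imports, `OpPattern(ScalarOp, output_types_preference=...)` matches Elemwise nodes whose scalar op upcasts its output. A minimal, hypothetical tracker for plain `Elemwise` nodes:

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.elemwise import Elemwise

@node_rewriter([Elemwise])
def my_noop_rewrite(fgraph, node):
    # Returning None means "no replacement proposed for this node".
    return None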
@@ -452,7 +459,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
     changed = False
     for i, inp in enumerate(node.inputs):
         if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant):
-            new_inputs[i] = constant(inp.data.astype(output_dtype))
+            new_inputs[i] = tensor_constant(inp.data.astype(output_dtype))
             changed = True
 
     if not changed:
@@ -531,7 +538,7 @@ class FusionOptimizer(GraphRewriter):
     @staticmethod
     def elemwise_to_scalar(inputs, outputs):
         replacement = {
-            inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
+            inp: get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
         }
         for node in toposort(outputs, blockers=inputs):
             scalar_inputs = [replacement[inp] for inp in node.inputs]
@@ -853,7 +860,7 @@ class FusionOptimizer(GraphRewriter):
             scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
             composite_outputs = Elemwise(
                 # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables
-                ps.Composite(scalar_inputs, scalar_outputs, clone_graph=False)
+                Composite(scalar_inputs, scalar_outputs, clone_graph=False)
             )(*inputs, return_list=True)
             assert len(outputs) == len(composite_outputs)
             for old_out, composite_out in zip(outputs, composite_outputs):
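
This is the step where `FusionOptimizer` collapses a chain of `Elemwise` nodes into one `Elemwise` wrapping a scalar `Composite`. A quick way to observe the effect (printout details vary by version):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x) + 1  # two Elemwise ops in the unoptimized graph

f = pytensor.function([x], y)  # default mode runs the fusion rewrites
pytensor.dprint(f)             # expect a single Elemwise wrapping Composite{...}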
@@ -913,7 +920,7 @@ class FusionOptimizer(GraphRewriter):
 
 @register_canonicalize
 @register_specialize
-@node_rewriter([elemwise_of(ps.Composite)])
+@node_rewriter([elemwise_of(Composite)])
 def local_useless_composite_outputs(fgraph, node):
     """Remove inputs and outputs of Composite Ops that are not used anywhere."""
     comp = node.op.scalar_op
@@ -934,7 +941,7 @@ def local_useless_composite_outputs(fgraph, node):
         node.outputs
     ):
         used_inputs = [node.inputs[i] for i in used_inputs_idxs]
-        c = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
+        c = Composite(inputs=used_inner_inputs, outputs=used_inner_outputs)
         e = Elemwise(scalar_op=c)(*used_inputs, return_list=True)
         return dict(zip([node.outputs[i] for i in used_outputs_idxs], e, strict=True))
 
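
The rewrite above rebuilds the op from only the used inner inputs/outputs. For reference, constructing a scalar `Composite` and lifting it with `Elemwise` by hand (variable names are illustrative):

import pytensor.scalar as ps
from pytensor.tensor.elemwise import Elemwise

xs = ps.float64("xs")
ys = ps.float64("ys")
comp = ps.Composite(inputs=[xs, ys], outputs=[ps.exp(xs) + ys])
elem = Elemwise(scalar_op=comp)  # applies the scalar graph elementwise to tensors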
@@ -948,7 +955,7 @@ def local_careduce_fusion(fgraph, node):
 
     # FIXME: This check is needed because of the faulty logic in the FIXME below!
     # Right now, rewrite only works for `Sum`/`Prod`
-    if not isinstance(car_scalar_op, ps.Add | ps.Mul):
+    if not isinstance(car_scalar_op, Add | Mul):
         return None
 
     elm_node = car_input.owner
@@ -992,19 +999,19 @@ def local_careduce_fusion(fgraph, node):
     car_acc_dtype = node.op.acc_dtype
 
     scalar_elm_inputs = [
-        ps.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
+        get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
     ]
 
     elm_output = elm_scalar_op(*scalar_elm_inputs)
 
     # This input represents the previous value in the `CAReduce` binary reduction
-    carried_car_input = ps.get_scalar_type(car_acc_dtype).make_variable()
+    carried_car_input = get_scalar_type(car_acc_dtype).make_variable()
     scalar_fused_output = car_scalar_op(carried_car_input, elm_output)
 
     if scalar_fused_output.type.dtype != car_acc_dtype:
-        scalar_fused_output = ps.cast(scalar_fused_output, car_acc_dtype)
+        scalar_fused_output = scalar_cast(scalar_fused_output, car_acc_dtype)
 
-    fused_scalar_op = ps.Composite(
+    fused_scalar_op = Composite(
         inputs=[carried_car_input, *scalar_elm_inputs], outputs=[scalar_fused_output]
     )
 
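
The fused `Composite` becomes the binary step of the `CAReduce`, so the elemwise computation runs inside the reduction loop instead of materializing an intermediate tensor. The idea in plain Python (a conceptual sketch, not the rewrite itself):

import math
from functools import reduce

xs = [0.5, 1.5, 2.5]

# Unfused: materialize exp(x) for every element, then reduce.
unfused = sum(math.exp(v) for v in xs)

# Fused: fold exp into the reduction step (carried, elem) -> carried + exp(elem).
fused = reduce(lambda carried, elem: carried + math.exp(elem), xs, 0.0)

assert abs(fused - unfused) < 1e-12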
@@ -1025,7 +1032,7 @@ def local_careduce_fusion(fgraph, node):
     return [new_car_op(*elm_inputs)]
 
 
-@node_rewriter([elemwise_of(ps.Composite)])
+@node_rewriter([elemwise_of(Composite)])
 def local_inline_composite_constants(fgraph, node):
     """Inline scalar constants in Composite graphs."""
     composite_op = node.op.scalar_op
@@ -1041,7 +1048,7 @@ def local_inline_composite_constants(fgraph, node):
             and "complex" not in outer_inp.type.dtype
         ):
             if outer_inp.unique_value is not None:
-                inner_replacements[inner_inp] = ps.constant(
+                inner_replacements[inner_inp] = scalar_constant(
                     outer_inp.unique_value, dtype=inner_inp.dtype
                 )
                 continue
@@ -1054,7 +1061,7 @@ def local_inline_composite_constants(fgraph, node):
     new_inner_outs = clone_replace(
         composite_op.fgraph.outputs, replace=inner_replacements
     )
-    new_composite_op = ps.Composite(new_inner_inputs, new_inner_outs)
+    new_composite_op = Composite(new_inner_inputs, new_inner_outs)
     new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs
 
     # Some of the inlined constants were broadcasting the output shape
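
`clone_replace`, now imported from `pytensor.graph.replace` instead of the top-level `pytensor` namespace, rebuilds a graph with selected variables substituted. A minimal sketch on a scalar graph (names are illustrative):

import pytensor.scalar as ps
from pytensor.graph.replace import clone_replace

x = ps.float64("x")
y = ps.float64("y")
out = x + y

# Substitute `y` with a scalar constant, as the constant-inlining rewrite does.
(new_out,) = clone_replace([out], replace={y: ps.constant(2.0, dtype="float64")})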
@@ -1095,7 +1102,7 @@ def constant_fold_branches_of_add_mul(fgraph, node):
         if other_inps:
             python_op = operator.mul if node.op == mul else operator.add
             folded_inputs = [reference_inp, *other_inps]
-            new_inp = constant(
+            new_inp = tensor_constant(
                 reduce(python_op, (const.data for const in folded_inputs))
             )
             new_constants = [
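
The fold itself happens eagerly in Python, reducing the constants' data with `operator.mul` or `operator.add` before a single new tensor constant is built. The core arithmetic in isolation:

import operator
from functools import reduce

const_data = [2.0, 3.0, 4.0]           # data of the constant branches
folded = reduce(operator.mul, const_data)
assert folded == 24.0                  # one constant replaces three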
@@ -1119,7 +1126,7 @@
 
 
 add_mul_fusion_seqopt = SequenceDB()
-compile.optdb.register(
+optdb.register(
     "add_mul_fusion",
     add_mul_fusion_seqopt,
     "fast_run",
@@ -1140,7 +1147,7 @@ add_mul_fusion_seqopt.register(
 
 # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
 fuse_seqopt = SequenceDB()
-compile.optdb.register(
+optdb.register(
     "elemwise_fusion",
     fuse_seqopt,
     "fast_run",
@@ -1271,7 +1278,7 @@ def split_2f1grad_loop(fgraph, node):
     return replacements
 
 
-compile.optdb["py_only"].register(
+optdb["py_only"].register(
     "split_2f1grad_loop",
     split_2f1grad_loop,
     "fast_compile",