提交 30e19e53 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Speed up FusionOptimizer.elemwise_to_scalar

上级 2269b2ec
......@@ -779,9 +779,11 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType:
This caches objects to save allocation and run time.
"""
if dtype not in cache:
cache[dtype] = ScalarType(dtype=dtype)
try:
return cache[dtype]
except KeyError:
cache[dtype] = res = ScalarType(dtype=dtype)
return res
# Register C code for ViewOp on Scalars.
......
......@@ -28,7 +28,7 @@ from pytensor.graph.rewriting.basic import (
)
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.graph.traversal import ancestors
from pytensor.graph.traversal import ancestors, toposort
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import (
......@@ -530,43 +530,24 @@ class FusionOptimizer(GraphRewriter):
# NOTE(review): this span is a rendered diff hunk whose +/- markers and
# indentation were stripped, so lines from two different bodies of
# `elemwise_to_scalar` are interleaved and the text is not valid Python as it
# stands. The grouping comments below are inferred from which local names each
# run of lines uses -- confirm against the actual commit before relying on them.
#
# In both variants the function takes `inputs` and `outputs`, creates one
# scalar variable per input via `ps.get_scalar_type(inp.type.dtype)`, rebuilds
# every node with `node.op.scalar_op.make_node(...)`, and returns a pair
# (scalar inputs, scalar outputs).
@staticmethod
def elemwise_to_scalar(inputs, outputs):
# --- Variant A (presumably the pre-change body) --------------------------
# Clones the inputs, rewrites `outputs` onto the clones with `clone_replace`,
# and wraps the result in a FunctionGraph. Scalar counterparts are tracked in
# parallel lists (`middle_inputs` / `middle_scalar_inputs`).
replace_inputs = [(inp, inp.clone()) for inp in inputs]
outputs = clone_replace(outputs, replace=replace_inputs)
inputs = [inp for _, inp in replace_inputs]
fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False)
middle_inputs = []
scalar_inputs = [
ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
]
middle_scalar_inputs = []
# --- Variant B (presumably the post-change body) -------------------------
# A single dict maps each original variable to its scalar counterpart; it is
# seeded with one fresh scalar variable per input.
replacement = {
inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
}
# Walk the nodes yielded by `toposort(outputs, blockers=inputs)`; each node's
# inputs are looked up in `replacement`, then all of the node's outputs are
# mapped to the outputs of the equivalent scalar node.
for node in toposort(outputs, blockers=inputs):
scalar_inputs = [replacement[inp] for inp in node.inputs]
replacement.update(
dict(
zip(
node.outputs,
node.op.scalar_op.make_node(*scalar_inputs).outputs,
)
)
)
# --- Variant A continues -------------------------------------------------
# Each node input is resolved with `in` / `.index()` on lists: first against
# `inputs`, then against `middle_inputs`.
for node in fg.toposort():
node_scalar_inputs = []
for inp in node.inputs:
if inp in inputs:
node_scalar_inputs.append(scalar_inputs[inputs.index(inp)])
elif inp in middle_inputs:
node_scalar_inputs.append(
middle_scalar_inputs[middle_inputs.index(inp)]
# --- Variant B: return statement -----------------------------------------
# Looks up the scalar counterparts of `inputs` and `outputs` in `replacement`.
# NOTE(review): the closing `)` below looks like a line shared by both
# variants (it would also close the `append(` call above) -- confirm.
return (
[replacement[inp] for inp in inputs],
[replacement[out] for out in outputs],
)
# --- Variant A: remainder ------------------------------------------------
# An input found in neither list gets a new scalar variable, which is recorded
# in the parallel lists.
else:
new_scalar_input = ps.get_scalar_type(
inp.type.dtype
).make_variable()
node_scalar_inputs.append(new_scalar_input)
middle_scalar_inputs.append(new_scalar_input)
middle_inputs.append(inp)
# Only `outputs[0]` of each node is recorded here, whereas Variant B zips over
# all of `node.outputs`.
new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs)
middle_scalar_inputs.append(new_scalar_node.outputs[0])
middle_inputs.append(node.outputs[0])
scalar_outputs = [
middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs
]
return scalar_inputs, scalar_outputs
def apply(self, fgraph):
if fgraph.profile:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论