提交 189ba03a 作者: Ricardo Vieira 提交者: Ricardo Vieira

Faster vectorize by walking sorted nodes

上级 e2d07514
......@@ -3,7 +3,13 @@ from collections.abc import Iterable, Mapping, Sequence
from functools import partial, singledispatch
from typing import Optional, Union, cast, overload
from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
from pytensor.graph.basic import (
Apply,
Constant,
Variable,
io_toposort,
truncated_graph_inputs,
)
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
......@@ -295,19 +301,14 @@ def vectorize_graph(
inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
new_inputs = [replace.get(inp, inp) for inp in inputs]
def transform(var: Variable) -> Variable:
if var in inputs:
return new_inputs[inputs.index(var)]
vect_vars = dict(zip(inputs, new_inputs))
for node in io_toposort(inputs, seq_outputs):
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs):
vect_vars[output] = vect_output
node = var.owner
batched_inputs = [transform(inp) for inp in node.inputs]
batched_node = vectorize_node(node, *batched_inputs)
batched_var = batched_node.outputs[var.owner.outputs.index(var)]
return cast(Variable, batched_var)
# TODO: MergeOptimization or node caching?
seq_vect_outputs = [transform(out) for out in seq_outputs]
seq_vect_outputs = [vect_vars[out] for out in seq_outputs]
if isinstance(outputs, Sequence):
return seq_vect_outputs
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论