提交 189ba03a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Faster vectorize by walking sorted nodes

上级 e2d07514
...@@ -3,7 +3,13 @@ from collections.abc import Iterable, Mapping, Sequence ...@@ -3,7 +3,13 @@ from collections.abc import Iterable, Mapping, Sequence
from functools import partial, singledispatch from functools import partial, singledispatch
from typing import Optional, Union, cast, overload from typing import Optional, Union, cast, overload
from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs from pytensor.graph.basic import (
Apply,
Constant,
Variable,
io_toposort,
truncated_graph_inputs,
)
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
...@@ -295,19 +301,14 @@ def vectorize_graph( ...@@ -295,19 +301,14 @@ def vectorize_graph(
inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys()) inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
new_inputs = [replace.get(inp, inp) for inp in inputs] new_inputs = [replace.get(inp, inp) for inp in inputs]
def transform(var: Variable) -> Variable: vect_vars = dict(zip(inputs, new_inputs))
if var in inputs: for node in io_toposort(inputs, seq_outputs):
return new_inputs[inputs.index(var)] vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs):
vect_vars[output] = vect_output
node = var.owner seq_vect_outputs = [vect_vars[out] for out in seq_outputs]
batched_inputs = [transform(inp) for inp in node.inputs]
batched_node = vectorize_node(node, *batched_inputs)
batched_var = batched_node.outputs[var.owner.outputs.index(var)]
return cast(Variable, batched_var)
# TODO: MergeOptimization or node caching?
seq_vect_outputs = [transform(out) for out in seq_outputs]
if isinstance(outputs, Sequence): if isinstance(outputs, Sequence):
return seq_vect_outputs return seq_vect_outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论