提交 6f8fc3b6 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Override default names for inner compiled fgraph functions

上级 b1678fd2
...@@ -68,7 +68,7 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs): ...@@ -68,7 +68,7 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
output_specs = [Out(o, borrow=False) for o in fgraph.outputs] output_specs = [Out(o, borrow=False) for o in fgraph.outputs]
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key( fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
fgraph, squeeze_output=True, **kwargs fgraph, squeeze_output=True, fgraph_name="numba_ofg", **kwargs
) )
if fgraph_cache_key is None: if fgraph_cache_key is None:
......
...@@ -236,7 +236,7 @@ def numba_funcify_Composite(op, node, **kwargs): ...@@ -236,7 +236,7 @@ def numba_funcify_Composite(op, node, **kwargs):
_ = kwargs.pop("storage_map", None) _ = kwargs.pop("storage_map", None)
composite_fn, fgraph_key = numba_funcify_and_cache_key( composite_fn, fgraph_key = numba_funcify_and_cache_key(
op.fgraph, squeeze_output=True, **kwargs op.fgraph, squeeze_output=True, fgraph_name="numba_composite", **kwargs
) )
if fgraph_key is None: if fgraph_key is None:
composite_key = None composite_key = None
......
...@@ -98,7 +98,9 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -98,7 +98,9 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
output_specs = [Out(x, borrow=False) for x in fgraph.outputs] output_specs = [Out(x, borrow=False) for x in fgraph.outputs]
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(op.fgraph) scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(
op.fgraph, fgraph_name="numba_scan"
)
outer_in_names_to_vars = { outer_in_names_to_vars = {
(f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs) (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论