提交 9859c799 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Generalize FunctionGraph conversion function with aesara.link.utils.fgraph_to_python

上级 e6914913
import ast
import re
import warnings import warnings
from collections import Counter
from functools import reduce, singledispatch from functools import reduce, singledispatch
from keyword import iskeyword
from tempfile import NamedTemporaryFile
from textwrap import indent
from types import FunctionType
from warnings import warn from warnings import warn
import jax import jax
...@@ -17,10 +10,9 @@ from numpy.random import RandomState ...@@ -17,10 +10,9 @@ from numpy.random import RandomState
from aesara.compile.ops import DeepCopyOp, ViewOp from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.ifelse import IfElse from aesara.ifelse import IfElse
from aesara.link.utils import map_storage from aesara.link.utils import fgraph_to_python
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.scan.utils import scan_args as ScanArgs from aesara.scan.utils import scan_args as ScanArgs
...@@ -104,7 +96,7 @@ incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1) ...@@ -104,7 +96,7 @@ incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
@singledispatch @singledispatch
def jax_typify(data, dtype): def jax_typify(data, dtype=None, **kwargs):
"""Convert instances of Aesara `Type`s to JAX types.""" """Convert instances of Aesara `Type`s to JAX types."""
if dtype is None: if dtype is None:
return data return data
...@@ -113,12 +105,12 @@ def jax_typify(data, dtype): ...@@ -113,12 +105,12 @@ def jax_typify(data, dtype):
@jax_typify.register(np.ndarray) @jax_typify.register(np.ndarray)
def jax_typify_ndarray(data, dtype): def jax_typify_ndarray(data, dtype=None, **kwargs):
return jnp.array(data, dtype=dtype) return jnp.array(data, dtype=dtype)
@jax_typify.register(RandomState) @jax_typify.register(RandomState)
def jax_typify_RandomState(state, dtype): def jax_typify_RandomState(state, **kwargs):
state = state.get_state(legacy=False) state = state.get_state(legacy=False)
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
return state return state
...@@ -608,92 +600,18 @@ def jax_funcify_FunctionGraph( ...@@ -608,92 +600,18 @@ def jax_funcify_FunctionGraph(
storage_map=None, storage_map=None,
**kwargs, **kwargs,
): ):
return fgraph_to_python(
if order is None: fgraph,
order = fgraph.toposort() jax_funcify,
input_storage, output_storage, storage_map = map_storage( jax_typify,
fgraph, order, input_storage, output_storage, storage_map order,
input_storage,
output_storage,
storage_map,
fgraph_name="jax_funcified_fgraph",
**kwargs,
) )
global_env = {}
fgraph_name = "jax_funcified_fgraph"
def unique_name(x, names_counter=Counter([fgraph_name]), obj_to_names={}):
if x in obj_to_names:
return obj_to_names[x]
if isinstance(x, Variable):
name = re.sub("[^0-9a-zA-Z]+", "_", x.name) if x.name else ""
name = (
name if (name.isidentifier() and not iskeyword(name)) else x.auto_name
)
elif isinstance(x, FunctionType):
name = x.__name__
else:
name = type(x).__name__
name_suffix = names_counter.get(name, "")
local_name = f"{name}{name_suffix}"
names_counter.update((name,))
obj_to_names[x] = local_name
return local_name
body_assigns = []
for node in order:
jax_func = jax_funcify(node.op, node=node, **kwargs)
# Create a local alias with a unique name
local_jax_func_name = unique_name(jax_func)
global_env[local_jax_func_name] = jax_func
node_input_names = []
for i in node.inputs:
local_input_name = unique_name(i)
if storage_map[i][0] is not None or isinstance(i, Constant):
# Constants need to be assigned locally and referenced
global_env[local_input_name] = jax_typify(storage_map[i][0], None)
# TODO: We could attempt to use the storage arrays directly
# E.g. `local_input_name = f"{local_input_name}[0]"`
node_input_names.append(local_input_name)
node_output_names = [unique_name(v) for v in node.outputs]
body_assigns.append(
f"{', '.join(node_output_names)} = {local_jax_func_name}({', '.join(node_input_names)})"
)
fgraph_input_names = [unique_name(v) for v in fgraph.inputs]
fgraph_output_names = [unique_name(v) for v in fgraph.outputs]
joined_body_assigns = indent("\n".join(body_assigns), " ")
if len(fgraph_output_names) == 1:
fgraph_return_src = f"({fgraph_output_names[0]},)"
else:
fgraph_return_src = ", ".join(fgraph_output_names)
fgraph_def_src = f"""
def {fgraph_name}({", ".join(fgraph_input_names)}):
{joined_body_assigns}
return {fgraph_return_src}
"""
fgraph_def_ast = ast.parse(fgraph_def_src)
# Create source code to be (at least temporarily) associated with the
# compiled function (e.g. for easier debugging)
with NamedTemporaryFile(delete=False) as f:
filename = f.name
f.write(fgraph_def_src.encode())
mod_code = compile(fgraph_def_ast, filename, mode="exec")
exec(mod_code, global_env, locals())
fgraph_def = locals()[fgraph_name]
return fgraph_def
@jax_funcify.register(CAReduce) @jax_funcify.register(CAReduce)
def jax_funcify_CAReduce(op, **kwargs): def jax_funcify_CAReduce(op, **kwargs):
......
...@@ -69,7 +69,9 @@ class JAXLinker(PerformLinker): ...@@ -69,7 +69,9 @@ class JAXLinker(PerformLinker):
for n in self.fgraph.inputs: for n in self.fgraph.inputs:
sinput = storage_map[n] sinput = storage_map[n]
if isinstance(sinput[0], RandomState): if isinstance(sinput[0], RandomState):
new_value = jax_typify(sinput[0], getattr(sinput[0], "dtype", None)) new_value = jax_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
# We need to remove the reference-based connection to the # We need to remove the reference-based connection to the
# original `RandomState`/shared variable's storage, because # original `RandomState`/shared variable's storage, because
# subsequent attempts to use the same shared variable within # subsequent attempts to use the same shared variable within
......
import ast
import io import io
import re
import sys import sys
import traceback import traceback
import warnings import warnings
from collections import Counter
from keyword import iskeyword
from operator import itemgetter from operator import itemgetter
from typing import Callable, Dict, Iterable, List, NoReturn, Optional, Tuple, Union from tempfile import NamedTemporaryFile
from textwrap import indent
from types import FunctionType
from typing import Any, Callable, Dict, Iterable, List, NoReturn, Optional, Tuple, Union
import numpy as np import numpy as np
from aesara import utils from aesara import utils
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
...@@ -564,3 +571,139 @@ def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout): ...@@ -564,3 +571,139 @@ def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout):
register_thunk_trace_excepthook() register_thunk_trace_excepthook()
def fgraph_to_python(
    fgraph: FunctionGraph,
    op_conversion_fn: Callable,
    type_conversion_fn: Optional[Callable] = lambda x, **kwargs: x,
    order: Optional[List[Apply]] = None,
    input_storage: Optional[List[Any]] = None,
    output_storage: Optional[List[Any]] = None,
    storage_map: Optional[Dict[Variable, List[Any]]] = None,
    fgraph_name: str = "fgraph_to_python",
    global_env: Optional[Dict[Any, Any]] = None,
    local_env: Optional[Dict[Any, Any]] = None,
    **kwargs,
) -> FunctionType:
    """Convert a ``FunctionGraph`` into a regular Python function.

    The graph is walked in `order`; every node is turned into a callable by
    `op_conversion_fn`, and the source code of a function that chains those
    callables together is generated, compiled and executed.

    Parameters
    ==========
    fgraph
        The ``FunctionGraph`` to convert.
    op_conversion_fn
        A callable used to convert nodes inside `fgraph` based on their ``Op``
        types.  It must have the signature ``(Op, **kwargs)``.  One of the
        keyword arguments will be ``node``, which provides the ``Apply`` node.
    type_conversion_fn
        A callable used to convert the values in `storage_map`.  It is called
        as ``type_conversion_fn(value, node=None, **kwargs)``.  ``None`` is
        treated like the default, i.e. values are used unchanged.
    order
        The ``order`` argument to ``map_storage``; the nodes of `fgraph` in
        the order they are to be evaluated.  Defaults to
        ``fgraph.toposort()``.
    input_storage
        The ``input_storage`` argument to ``map_storage``.
    output_storage
        The ``output_storage`` argument to ``map_storage``.
    storage_map
        The ``storage_map`` argument to ``map_storage``.
    fgraph_name
        The name used for the resulting function.
    global_env
        The global environment used when the function is constructed.
        The default is an empty ``dict``.  The converted ``Op`` callables and
        the converted stored values are added to this mapping.
    local_env
        The local environment used when the function is constructed.
        The default is ``locals()``.
    **kwargs
        The remaining keywords are passed to `op_conversion_fn` and
        `type_conversion_fn`.

    Returns
    =======
    The generated function, named `fgraph_name`.  Its parameters correspond
    to ``fgraph.inputs``, and it returns a tuple with one entry per element of
    ``fgraph.outputs``.

    """
    if order is None:
        order = fgraph.toposort()

    input_storage, output_storage, storage_map = map_storage(
        fgraph, order, input_storage, output_storage, storage_map
    )

    if global_env is None:
        global_env = {}

    if type_conversion_fn is None:
        # The parameter is declared `Optional`; fall back to the identity.
        def type_conversion_fn(x, **kwargs):
            return x

    # Per-call naming state.  The function's own name is reserved up front so
    # that no generated identifier can shadow it.
    names_counter = Counter([fgraph_name])
    obj_to_names = {}

    def unique_name(x):
        """Return a stable identifier for `x` that is unique within this call."""
        if x in obj_to_names:
            return obj_to_names[x]

        if isinstance(x, Variable):
            name = re.sub("[^0-9a-zA-Z]+", "_", x.name) if x.name else ""
            name = (
                name if (name.isidentifier() and not iskeyword(name)) else x.auto_name
            )
        elif isinstance(x, FunctionType):
            name = x.__name__
        else:
            name = type(x).__name__

        # The first use of a base name gets no suffix; later uses get the
        # number of previous uses appended.
        name_suffix = names_counter.get(name, "")
        local_name = f"{name}{name_suffix}"
        names_counter.update((name,))
        obj_to_names[x] = local_name

        return local_name

    body_assigns = []
    for node in order:
        node_func = op_conversion_fn(node.op, node=node, **kwargs)

        # Create a local alias with a unique name
        local_node_func_name = unique_name(node_func)
        global_env[local_node_func_name] = node_func

        node_input_names = []
        for i in node.inputs:
            local_input_name = unique_name(i)
            if storage_map[i][0] is not None or isinstance(i, Constant):
                # Constants need to be assigned locally and referenced
                global_env[local_input_name] = type_conversion_fn(
                    storage_map[i][0], node=None, **kwargs
                )
                # TODO: We could attempt to use the storage arrays directly
                # E.g. `local_input_name = f"{local_input_name}[0]"`
            node_input_names.append(local_input_name)

        node_output_names = [unique_name(v) for v in node.outputs]

        body_assigns.append(
            f"{', '.join(node_output_names)} = {local_node_func_name}({', '.join(node_input_names)})"
        )

    fgraph_input_names = [unique_name(v) for v in fgraph.inputs]
    fgraph_output_names = [unique_name(v) for v in fgraph.outputs]

    # The body and the `return` statement must share the same indentation,
    # otherwise the generated source does not compile.
    joined_body_assigns = indent("\n".join(body_assigns), "    ")

    if len(fgraph_output_names) == 1:
        # A trailing comma is needed so that a single output is still a tuple
        fgraph_return_src = f"({fgraph_output_names[0]},)"
    else:
        fgraph_return_src = ", ".join(fgraph_output_names)

    fgraph_def_src = f"""
def {fgraph_name}({", ".join(fgraph_input_names)}):
{joined_body_assigns}
    return {fgraph_return_src}
    """

    fgraph_def_ast = ast.parse(fgraph_def_src)

    # Create source code to be (at least temporarily) associated with the
    # compiled function (e.g. for easier debugging).  The file is deliberately
    # not deleted, so that it is still available when a traceback is rendered.
    with NamedTemporaryFile(delete=False) as f:
        filename = f.name
        f.write(fgraph_def_src.encode())

    if local_env is None:
        local_env = locals()

    mod_code = compile(fgraph_def_ast, filename, mode="exec")
    exec(mod_code, global_env, local_env)

    fgraph_def = local_env[fgraph_name]

    return fgraph_def
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论