提交 fe2b830f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Provide the storage map during FunctionGraph conversion calls

This makes the outer `FunctionGraph` storage available for (re)use by inner `FunctionGraphs` (e.g. shared variable values).
上级 5fbaecc3
......@@ -117,7 +117,7 @@ def jax_typify_RandomState(state, **kwargs):
@singledispatch
def jax_funcify(op, **kwargs):
def jax_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a JAX compatible function from an Aesara `Op`."""
raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
......@@ -594,21 +594,15 @@ def jax_funcify_AdvancedIncSubtensor(op, **kwargs):
@jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph(
fgraph,
order=None,
input_storage=None,
output_storage=None,
storage_map=None,
node=None,
fgraph_name="jax_funcified_fgraph",
**kwargs,
):
return fgraph_to_python(
fgraph,
jax_funcify,
jax_typify,
order,
input_storage,
output_storage,
storage_map,
fgraph_name="jax_funcified_fgraph",
type_conversion_fn=jax_typify,
fgraph_name=fgraph_name,
**kwargs,
)
......
......@@ -7,14 +7,10 @@ from aesara.link.basic import JITLinker
class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""
def fgraph_convert(
self, fgraph, order, input_storage, output_storage, storage_map, **kwargs
):
def fgraph_convert(self, fgraph, **kwargs):
from aesara.link.jax.dispatch import jax_funcify
return jax_funcify(
fgraph, order, input_storage, output_storage, storage_map, **kwargs
)
return jax_funcify(fgraph, **kwargs)
def jit_compile(self, fn):
import jax
......
......@@ -60,7 +60,7 @@ def numba_typify(data, dtype=None, **kwargs):
@singledispatch
def numba_funcify(op, **kwargs):
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`."""
raise NotImplementedError(f"No Numba conversion for the given `Op`: {op}")
......@@ -68,27 +68,23 @@ def numba_funcify(op, **kwargs):
@numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
fgraph,
order=None,
input_storage=None,
output_storage=None,
storage_map=None,
node=None,
fgraph_name="jax_funcified_fgraph",
**kwargs,
):
return fgraph_to_python(
fgraph,
numba_funcify,
numba_typify,
order,
input_storage,
output_storage,
storage_map,
fgraph_name="numba_funcified_fgraph",
type_conversion_fn=numba_typify,
fgraph_name=fgraph_name,
**kwargs,
)
@numba_funcify.register(ScalarOp)
def numba_funcify_ScalarOp(op, node, **kwargs):
# TODO: Do we need to cache these functions so that we don't end up
# compiling the same Numba function over and over again?
scalar_func_name = op.nfunc_spec[0]
......
......@@ -6,14 +6,10 @@ from aesara.link.basic import JITLinker
class NumbaLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""
def fgraph_convert(
self, fgraph, order, input_storage, output_storage, storage_map, **kwargs
):
def fgraph_convert(self, fgraph, **kwargs):
from aesara.link.numba.dispatch import numba_funcify
return numba_funcify(
fgraph, order, input_storage, output_storage, storage_map, **kwargs
)
return numba_funcify(fgraph, **kwargs)
def jit_compile(self, fn):
jitted_fn = numba.njit(fn)
......
......@@ -595,6 +595,7 @@ def compile_function_src(src, function_name, global_env=None, local_env=None):
def fgraph_to_python(
fgraph: FunctionGraph,
op_conversion_fn: Callable,
*,
type_conversion_fn: Optional[Callable] = lambda x, **kwargs: x,
order: Optional[List[Variable]] = None,
input_storage: Optional[List[Any]] = None,
......@@ -613,10 +614,12 @@ def fgraph_to_python(
The ``FunctionGraph`` to convert.
op_conversion_fn
A callable used to convert nodes inside `fgraph` based on their ``Op``
types. It must have the signature ``(Op, **kwargs)``. One of the
keyword arguments will be ``node``, which provides the ``Apply`` node.
types. It must have the signature
``(op: Op, node: Apply=None, storage_map: Dict[Variable, List[Optional[Any]]]=None, **kwargs)``.
type_conversion_fn
A callable used to convert the values in `storage_map`.
A callable used to convert the values in `storage_map`. It must have
the signature
``(value: Optional[Any], variable: Variable=None, storage: List[Optional[Any]]=None, **kwargs)``.
order
The ``order`` argument to ``map_storage``.
input_storage
......@@ -670,7 +673,9 @@ def fgraph_to_python(
body_assigns = []
for node in order:
jax_func = op_conversion_fn(node.op, node=node, **kwargs)
jax_func = op_conversion_fn(
node.op, node=node, storage_map=storage_map, **kwargs
)
# Create a local alias with a unique name
local_jax_func_name = unique_name(jax_func)
......@@ -682,7 +687,7 @@ def fgraph_to_python(
if storage_map[i][0] is not None or isinstance(i, Constant):
# Constants need to be assigned locally and referenced
global_env[local_input_name] = type_conversion_fn(
storage_map[i][0], node=None, **kwargs
storage_map[i][0], variable=i, storage=storage_map[i], **kwargs
)
# TODO: We could attempt to use the storage arrays directly
# E.g. `local_input_name = f"{local_input_name}[0]"`
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论