提交 fe2b830f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Provide the storage map during FunctionGraph conversion calls

This makes the outer `FunctionGraph` storage available for (re)use by inner `FunctionGraphs` (e.g. shared variable values).
上级 5fbaecc3
...@@ -117,7 +117,7 @@ def jax_typify_RandomState(state, **kwargs): ...@@ -117,7 +117,7 @@ def jax_typify_RandomState(state, **kwargs):
@singledispatch @singledispatch
def jax_funcify(op, **kwargs): def jax_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a JAX compatible function from an Aesara `Op`.""" """Create a JAX compatible function from an Aesara `Op`."""
raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}") raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
...@@ -594,21 +594,15 @@ def jax_funcify_AdvancedIncSubtensor(op, **kwargs): ...@@ -594,21 +594,15 @@ def jax_funcify_AdvancedIncSubtensor(op, **kwargs):
@jax_funcify.register(FunctionGraph) @jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph( def jax_funcify_FunctionGraph(
fgraph, fgraph,
order=None, node=None,
input_storage=None, fgraph_name="jax_funcified_fgraph",
output_storage=None,
storage_map=None,
**kwargs, **kwargs,
): ):
return fgraph_to_python( return fgraph_to_python(
fgraph, fgraph,
jax_funcify, jax_funcify,
jax_typify, type_conversion_fn=jax_typify,
order, fgraph_name=fgraph_name,
input_storage,
output_storage,
storage_map,
fgraph_name="jax_funcified_fgraph",
**kwargs, **kwargs,
) )
......
...@@ -7,14 +7,10 @@ from aesara.link.basic import JITLinker ...@@ -7,14 +7,10 @@ from aesara.link.basic import JITLinker
class JAXLinker(JITLinker): class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX.""" """A `Linker` that JIT-compiles NumPy-based operations using JAX."""
def fgraph_convert( def fgraph_convert(self, fgraph, **kwargs):
self, fgraph, order, input_storage, output_storage, storage_map, **kwargs
):
from aesara.link.jax.dispatch import jax_funcify from aesara.link.jax.dispatch import jax_funcify
return jax_funcify( return jax_funcify(fgraph, **kwargs)
fgraph, order, input_storage, output_storage, storage_map, **kwargs
)
def jit_compile(self, fn): def jit_compile(self, fn):
import jax import jax
......
...@@ -60,7 +60,7 @@ def numba_typify(data, dtype=None, **kwargs): ...@@ -60,7 +60,7 @@ def numba_typify(data, dtype=None, **kwargs):
@singledispatch @singledispatch
def numba_funcify(op, **kwargs): def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`.""" """Create a Numba compatible function from an Aesara `Op`."""
raise NotImplementedError(f"No Numba conversion for the given `Op`: {op}") raise NotImplementedError(f"No Numba conversion for the given `Op`: {op}")
...@@ -68,27 +68,23 @@ def numba_funcify(op, **kwargs): ...@@ -68,27 +68,23 @@ def numba_funcify(op, **kwargs):
@numba_funcify.register(FunctionGraph) @numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph( def numba_funcify_FunctionGraph(
fgraph, fgraph,
order=None, node=None,
input_storage=None, fgraph_name="jax_funcified_fgraph",
output_storage=None,
storage_map=None,
**kwargs, **kwargs,
): ):
return fgraph_to_python( return fgraph_to_python(
fgraph, fgraph,
numba_funcify, numba_funcify,
numba_typify, type_conversion_fn=numba_typify,
order, fgraph_name=fgraph_name,
input_storage,
output_storage,
storage_map,
fgraph_name="numba_funcified_fgraph",
**kwargs, **kwargs,
) )
@numba_funcify.register(ScalarOp) @numba_funcify.register(ScalarOp)
def numba_funcify_ScalarOp(op, node, **kwargs): def numba_funcify_ScalarOp(op, node, **kwargs):
# TODO: Do we need to cache these functions so that we don't end up
# compiling the same Numba function over and over again?
scalar_func_name = op.nfunc_spec[0] scalar_func_name = op.nfunc_spec[0]
......
...@@ -6,14 +6,10 @@ from aesara.link.basic import JITLinker ...@@ -6,14 +6,10 @@ from aesara.link.basic import JITLinker
class NumbaLinker(JITLinker): class NumbaLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Numba.""" """A `Linker` that JIT-compiles NumPy-based operations using Numba."""
def fgraph_convert( def fgraph_convert(self, fgraph, **kwargs):
self, fgraph, order, input_storage, output_storage, storage_map, **kwargs
):
from aesara.link.numba.dispatch import numba_funcify from aesara.link.numba.dispatch import numba_funcify
return numba_funcify( return numba_funcify(fgraph, **kwargs)
fgraph, order, input_storage, output_storage, storage_map, **kwargs
)
def jit_compile(self, fn): def jit_compile(self, fn):
jitted_fn = numba.njit(fn) jitted_fn = numba.njit(fn)
......
...@@ -595,6 +595,7 @@ def compile_function_src(src, function_name, global_env=None, local_env=None): ...@@ -595,6 +595,7 @@ def compile_function_src(src, function_name, global_env=None, local_env=None):
def fgraph_to_python( def fgraph_to_python(
fgraph: FunctionGraph, fgraph: FunctionGraph,
op_conversion_fn: Callable, op_conversion_fn: Callable,
*,
type_conversion_fn: Optional[Callable] = lambda x, **kwargs: x, type_conversion_fn: Optional[Callable] = lambda x, **kwargs: x,
order: Optional[List[Variable]] = None, order: Optional[List[Variable]] = None,
input_storage: Optional[List[Any]] = None, input_storage: Optional[List[Any]] = None,
...@@ -613,10 +614,12 @@ def fgraph_to_python( ...@@ -613,10 +614,12 @@ def fgraph_to_python(
The ``FunctionGraph`` to convert. The ``FunctionGraph`` to convert.
op_conversion_fn op_conversion_fn
A callable used to convert nodes inside `fgraph` based on their ``Op`` A callable used to convert nodes inside `fgraph` based on their ``Op``
types. It must have the signature ``(Op, **kwargs)``. One of the types. It must have the signature
keyword arguments will be ``node``, which provides the ``Apply`` node. ``(op: Op, node: Apply=None, storage_map: Dict[Variable, List[Optional[Any]]]=None, **kwargs)``.
type_conversion_fn type_conversion_fn
A callable used to convert the values in `storage_map`. A callable used to convert the values in `storage_map`. It must have
the signature
``(value: Optional[Any], variable: Variable=None, storage: List[Optional[Any]]=None, **kwargs)``.
order order
The ``order`` argument to ``map_storage``. The ``order`` argument to ``map_storage``.
input_storage input_storage
...@@ -670,7 +673,9 @@ def fgraph_to_python( ...@@ -670,7 +673,9 @@ def fgraph_to_python(
body_assigns = [] body_assigns = []
for node in order: for node in order:
jax_func = op_conversion_fn(node.op, node=node, **kwargs) jax_func = op_conversion_fn(
node.op, node=node, storage_map=storage_map, **kwargs
)
# Create a local alias with a unique name # Create a local alias with a unique name
local_jax_func_name = unique_name(jax_func) local_jax_func_name = unique_name(jax_func)
...@@ -682,7 +687,7 @@ def fgraph_to_python( ...@@ -682,7 +687,7 @@ def fgraph_to_python(
if storage_map[i][0] is not None or isinstance(i, Constant): if storage_map[i][0] is not None or isinstance(i, Constant):
# Constants need to be assigned locally and referenced # Constants need to be assigned locally and referenced
global_env[local_input_name] = type_conversion_fn( global_env[local_input_name] = type_conversion_fn(
storage_map[i][0], node=None, **kwargs storage_map[i][0], variable=i, storage=storage_map[i], **kwargs
) )
# TODO: We could attempt to use the storage arrays directly # TODO: We could attempt to use the storage arrays directly
# E.g. `local_input_name = f"{local_input_name}[0]"` # E.g. `local_input_name = f"{local_input_name}[0]"`
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论