提交 1549649f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Create a generalized JITLinker

上级 9859c799
from abc import ABC, abstractmethod
from copy import copy, deepcopy
from typing import (
TYPE_CHECKING,
......@@ -146,7 +147,7 @@ class Container:
return r
class Linker:
class Linker(ABC):
"""
Base type for all linkers.
......@@ -189,6 +190,7 @@ class Linker:
new._allow_gc = allow_gc
return new
@abstractmethod
def make_thunk(self, **kwargs) -> ThunkType:
"""
This function must return a triplet (function, input_variables,
......@@ -211,9 +213,6 @@ class Linker:
print e.data # 3.0 iff inplace == True (else unknown)
"""
raise NotImplementedError(
f"make_thunk method of {type(self)} is not implemented."
)
@deprecated("Marked for deletion. Only tests use it.")
def make_function(self, unpack_single: bool = True, **kwargs) -> Callable:
......@@ -630,3 +629,171 @@ def WrapLinkerMany(
f(*args)
return WrapLinker(linkers, wrapper)
class JITLinker(PerformLinker):
    """A ``Linker`` that JIT compiles a ``FunctionGraph`` into a single thunk.

    The whole of ``self.fgraph`` is converted (`fgraph_convert`), JIT compiled
    (`jit_compile`) and wrapped in one thunk; `make_all` then hands that single
    thunk to ``streamline`` to build the callable that is returned.

    Subclasses supply the backend-specific pieces by implementing
    `fgraph_convert`, `create_thunk_inputs` and `jit_compile`.

    """

    @abstractmethod
    def fgraph_convert(
        self, fgraph, order, input_storage, output_storage, storage_map, **kwargs
    ):
        """Convert a ``FunctionGraph`` into a JIT-able function."""

    @abstractmethod
    def create_thunk_inputs(self, storage_map: Dict[Variable, List[Any]]) -> List[Any]:
        """Pre-process inputs for the generated thunk.

        Parameters
        ----------
        storage_map
            A ``dict`` mapping ``Variable``s to their storage lists.

        Returns
        -------
        A list of thunk inputs

        """

    @abstractmethod
    def jit_compile(self, fn: Callable) -> Callable:
        """JIT compile a converted ``FunctionGraph``."""

    def create_jitable_thunk(
        self, compute_map, order, input_storage, output_storage, storage_map
    ):
        """Create one thunk that evaluates the whole `FunctionGraph`.

        This differs from the other thunk-making functions in that it does
        not produce a thunk per node: the converted and JIT-compiled graph is
        wrapped in a single thunk that fills the storage of every graph
        output.

        Parameters
        ----------
        compute_map: dict
            The compute map dictionary.  The generated thunk marks each graph
            output as computed in it.
        order
            Forwarded to `fgraph_convert` as ``order``.
        input_storage
            Forwarded to `fgraph_convert` as ``input_storage``.
        output_storage
            Forwarded to `fgraph_convert` as ``output_storage``.
        storage_map: dict
            The storage map dictionary.

        Returns
        -------
        thunks: list
            A list containing the single generated thunk.
        output_nodes: list
            A list containing at most one node: the owner of the first graph
            output.

        """
        graph = self.fgraph
        owners = [out_var.owner for out_var in graph.outputs]

        jitable_fn = self.fgraph_convert(
            graph,
            order=order,
            input_storage=input_storage,
            output_storage=output_storage,
            storage_map=storage_map,
        )

        thunk_inputs = self.create_thunk_inputs(storage_map)
        thunk_outputs = [storage_map[out_var] for out_var in graph.outputs]

        fgraph_jit = self.jit_compile(jitable_fn)

        # `compute_map` is captured from the enclosing scope; everything else
        # the thunk needs is bound through default arguments.
        def thunk(
            fgraph=graph,
            fgraph_jit=fgraph_jit,
            thunk_inputs=thunk_inputs,
            thunk_outputs=thunk_outputs,
        ):
            outputs = fgraph_jit(*(cell[0] for cell in thunk_inputs))

            for out_var, out_cell, out_val in zip(
                fgraph.outputs, thunk_outputs, outputs
            ):
                compute_map[out_var][0] = True

                if len(out_cell) <= 1:
                    out_cell[0] = out_val
                    continue

                # Storage with several slots is filled slot by slot.
                assert len(out_cell) == len(out_val)
                for idx, sub_val in enumerate(out_val):
                    out_cell[idx] = sub_val

            return outputs

        thunk.inputs = thunk_inputs
        thunk.outputs = thunk_outputs
        thunk.lazy = False

        # This is a bit hackish, but we only return one of the output nodes
        return [thunk], owners[:1]

    def make_all(self, input_storage=None, output_storage=None, storage_map=None):
        """Build the callable and the storage containers for ``self.fgraph``.

        Parameters
        ----------
        input_storage
            Passed to ``map_storage``.
        output_storage
            Passed to ``map_storage``.
        storage_map
            Passed to ``map_storage``.

        Returns
        -------
        A 5-tuple ``(fn, input_containers, output_containers, thunks, nodes)``,
        where ``fn`` is the result of ``streamline``, and ``thunks`` and
        ``nodes`` are the values returned by `create_jitable_thunk`.

        """
        fgraph = self.fgraph
        nodes = self.schedule(fgraph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            fgraph, nodes, input_storage, output_storage, storage_map
        )

        # Variables without an owner start out flagged as computed.
        compute_map = {var: [var.owner is None] for var in storage_map}

        thunks, nodes = self.create_jitable_thunk(
            compute_map, nodes, input_storage, output_storage, storage_map
        )

        computed, last_user = gc_helper(nodes)

        post_thunk_old_storage = None
        if self.allow_gc:
            # For each node, collect the storage of the computed, non-output
            # inputs for which this node is the last user.
            post_thunk_old_storage = [
                [
                    storage_map[inp]
                    for inp in node.inputs
                    if (inp in computed)
                    and (inp not in fgraph.outputs)
                    and (node == last_user[inp])
                ]
                for node in nodes
            ]

        if no_recycling is True:
            no_recycling = difference(list(storage_map.values()), input_storage)
        else:
            no_recycling = [
                storage_map[var] for var in no_recycling if var not in fgraph.inputs
            ]

        fn = streamline(
            fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
        )
        fn.allow_gc = self.allow_gc
        fn.storage_map = storage_map

        input_containers = [
            Container(in_var, cell)
            for in_var, cell in zip(fgraph.inputs, input_storage)
        ]
        output_containers = [
            Container(out_var, cell, readonly=True)
            for out_var, cell in zip(fgraph.outputs, output_storage)
        ]

        return fn, input_containers, output_containers, thunks, nodes
from warnings import warn
from numpy.random import RandomState
from aesara.graph.basic import Constant
from aesara.link.basic import Container, PerformLinker
from aesara.link.utils import gc_helper, map_storage, streamline
from aesara.utils import difference
class JAXLinker(PerformLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX.
from aesara.link.basic import JITLinker
Attributes
----------
allow_non_jax: bool
A boolean indicating whether or not an exception is thrown when the
graph cannot be JAX compiled (e.g. the graph has an unsupported operator).
If `allow_non_jax` is `True`, the fallback is currently Python compilation.
"""
class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""
allow_non_jax = False
def create_jax_thunks(
self, compute_map, order, input_storage, output_storage, storage_map
def fgraph_convert(
self, fgraph, order, input_storage, output_storage, storage_map, **kwargs
):
"""Create a thunk for each output of the `Linker`s `FunctionGraph`.
This differs from the other thunk-making functions in that it only
produces thunks for the `FunctionGraph` output nodes.
from aesara.link.jax.dispatch import jax_funcify
Parameters
----------
compute_map: dict
The compute map dictionary.
storage_map: dict
The storage map dictionary.
Returns
-------
thunks: list
A tuple containing the thunks.
output_nodes: list
A tuple containing the output nodes.
return jax_funcify(
fgraph, order, input_storage, output_storage, storage_map, **kwargs
)
"""
def jit_compile(self, fn):
import jax
from aesara.link.jax.dispatch import jax_funcify, jax_typify
output_nodes = [o.owner for o in self.fgraph.outputs]
# Create a JAX-compilable function from our `FunctionGraph`
jaxed_fgraph = jax_funcify(
self.fgraph,
input_storage=input_storage,
output_storage=output_storage,
storage_map=storage_map,
)
# I suppose we can consider `Constant`s to be "static" according to
# JAX.
static_argnums = [
n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
]
return jax.jit(fn, static_argnums)
def create_thunk_inputs(self, storage_map):
from aesara.link.jax.dispatch import jax_typify
thunk_inputs = []
for n in self.fgraph.inputs:
......@@ -79,121 +43,4 @@ class JAXLinker(PerformLinker):
sinput = [new_value]
thunk_inputs.append(sinput)
thunks = []
thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
fgraph_jit = jax.jit(jaxed_fgraph, static_argnums)
def thunk(
fgraph=self.fgraph,
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
for o_node, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
compute_map[o_node][0] = True
if len(o_storage) > 1:
assert len(o_storage) == len(o_val)
for i, o_sub_val in enumerate(o_val):
o_storage[i] = o_sub_val
else:
o_storage[0] = o_val
return outputs
thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs
thunk.lazy = False
thunks.append(thunk)
# This is a bit hackish, but we only return one of the output nodes
return thunks, output_nodes[:1]
def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph
nodes = self.schedule(fgraph)
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(
fgraph, nodes, input_storage, output_storage, storage_map
)
compute_map = {}
for k in storage_map:
compute_map[k] = [k.owner is None]
try:
# We need to create thunk functions that will populate the output
# storage arrays with the JAX-computed values.
thunks, nodes = self.create_jax_thunks(
compute_map, nodes, input_storage, output_storage, storage_map
)
except NotImplementedError as e:
if not self.allow_non_jax:
raise
warn(f"JaxLinker could not JAXify graph: {e}")
thunks = []
for node in nodes:
thunk = node.op.make_thunk(
node, storage_map, compute_map, no_recycling, "py"
)
thunk_inputs = [storage_map[v] for v in node.inputs]
thunk_outputs = [storage_map[v] for v in node.outputs]
thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs
thunks.append(thunk)
computed, last_user = gc_helper(nodes)
if self.allow_gc:
post_thunk_old_storage = []
for node in nodes:
post_thunk_old_storage.append(
[
storage_map[input]
for input in node.inputs
if (input in computed)
and (input not in fgraph.outputs)
and (node == last_user[input])
]
)
else:
post_thunk_old_storage = None
if no_recycling is True:
no_recycling = list(storage_map.values())
no_recycling = difference(no_recycling, input_storage)
else:
no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs
]
fn = streamline(
fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
)
fn.allow_gc = self.allow_gc
fn.storage_map = storage_map
return (
fn,
[
Container(input, storage)
for input, storage in zip(fgraph.inputs, input_storage)
],
[
Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage)
],
thunks,
nodes,
)
return thunk_inputs
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论