提交 1549649f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Create a generalized JITLinker

上级 9859c799
from abc import ABC, abstractmethod
from copy import copy, deepcopy from copy import copy, deepcopy
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
...@@ -146,7 +147,7 @@ class Container: ...@@ -146,7 +147,7 @@ class Container:
return r return r
class Linker: class Linker(ABC):
""" """
Base type for all linkers. Base type for all linkers.
...@@ -189,6 +190,7 @@ class Linker: ...@@ -189,6 +190,7 @@ class Linker:
new._allow_gc = allow_gc new._allow_gc = allow_gc
return new return new
@abstractmethod
def make_thunk(self, **kwargs) -> ThunkType: def make_thunk(self, **kwargs) -> ThunkType:
""" """
This function must return a triplet (function, input_variables, This function must return a triplet (function, input_variables,
...@@ -211,9 +213,6 @@ class Linker: ...@@ -211,9 +213,6 @@ class Linker:
print e.data # 3.0 iff inplace == True (else unknown) print e.data # 3.0 iff inplace == True (else unknown)
""" """
raise NotImplementedError(
f"make_thunk method of {type(self)} is not implemented."
)
@deprecated("Marked for deletion. Only tests use it.") @deprecated("Marked for deletion. Only tests use it.")
def make_function(self, unpack_single: bool = True, **kwargs) -> Callable: def make_function(self, unpack_single: bool = True, **kwargs) -> Callable:
...@@ -630,3 +629,171 @@ def WrapLinkerMany( ...@@ -630,3 +629,171 @@ def WrapLinkerMany(
f(*args) f(*args)
return WrapLinker(linkers, wrapper) return WrapLinker(linkers, wrapper)
class JITLinker(PerformLinker):
"""A ``Linker`` that JIT compiles a ``FunctionGraph`` into a single runnable thunk.
The entirety of ``Linker.fgraph`` is converted into a single JIT compiled
thunk that is run by an Aesara ``VM``.
"""
@abstractmethod
def fgraph_convert(
self, fgraph, order, input_storage, output_storage, storage_map, **kwargs
):
"""Convert a ``FunctionGraph`` into a JIT-able function."""
@abstractmethod
def create_thunk_inputs(self, storage_map: Dict[Variable, List[Any]]) -> List[Any]:
"""Pre-process inputs for the generated thunk.
Parameters
==========
storage_map
A ``dict`` mapping ``Variable``s to their storage lists.
Returns
=======
A list of thunk inputs
"""
@abstractmethod
def jit_compile(self, fn: Callable) -> Callable:
"""JIT compile a converted ``FunctionGraph``."""
def create_jitable_thunk(
self, compute_map, order, input_storage, output_storage, storage_map
):
"""Create a thunk for each output of the `Linker`s `FunctionGraph`.
This is differs from the other thunk-making function in that it only
produces thunks for the `FunctionGraph` output nodes.
Parameters
----------
compute_map: dict
The compute map dictionary.
order
input_storage
output_storage
storage_map: dict
The storage map dictionary.
Returns
-------
thunks: list
A tuple containing the thunks.
output_nodes: list and their
A tuple containing the output nodes.
"""
output_nodes = [o.owner for o in self.fgraph.outputs]
converted_fgraph = self.fgraph_convert(
self.fgraph,
order=order,
input_storage=input_storage,
output_storage=output_storage,
storage_map=storage_map,
)
thunk_inputs = self.create_thunk_inputs(storage_map)
thunks = []
thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
fgraph_jit = self.jit_compile(converted_fgraph)
def thunk(
fgraph=self.fgraph,
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
for o_node, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
compute_map[o_node][0] = True
if len(o_storage) > 1:
assert len(o_storage) == len(o_val)
for i, o_sub_val in enumerate(o_val):
o_storage[i] = o_sub_val
else:
o_storage[0] = o_val
return outputs
thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs
thunk.lazy = False
thunks.append(thunk)
# This is a bit hackish, but we only return one of the output nodes
return thunks, output_nodes[:1]
def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph
nodes = self.schedule(fgraph)
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(
fgraph, nodes, input_storage, output_storage, storage_map
)
compute_map = {}
for k in storage_map:
compute_map[k] = [k.owner is None]
thunks, nodes = self.create_jitable_thunk(
compute_map, nodes, input_storage, output_storage, storage_map
)
computed, last_user = gc_helper(nodes)
if self.allow_gc:
post_thunk_old_storage = []
for node in nodes:
post_thunk_old_storage.append(
[
storage_map[input]
for input in node.inputs
if (input in computed)
and (input not in fgraph.outputs)
and (node == last_user[input])
]
)
else:
post_thunk_old_storage = None
if no_recycling is True:
no_recycling = list(storage_map.values())
no_recycling = difference(no_recycling, input_storage)
else:
no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs
]
fn = streamline(
fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
)
fn.allow_gc = self.allow_gc
fn.storage_map = storage_map
return (
fn,
[
Container(input, storage)
for input, storage in zip(fgraph.inputs, input_storage)
],
[
Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage)
],
thunks,
nodes,
)
from warnings import warn
from numpy.random import RandomState from numpy.random import RandomState
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.link.basic import Container, PerformLinker from aesara.link.basic import JITLinker
from aesara.link.utils import gc_helper, map_storage, streamline
from aesara.utils import difference
class JAXLinker(PerformLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX.
Attributes
----------
allow_non_jax: bool
A boolean indicating whether or not an exception is thrown when the
graph cannot be JAX compiled (e.g. the graph has an unsupported operator).
If `allow_non_jax` is `True`, the fallback is currently Python compilation.
""" class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""
allow_non_jax = False def fgraph_convert(
self, fgraph, order, input_storage, output_storage, storage_map, **kwargs
def create_jax_thunks(
self, compute_map, order, input_storage, output_storage, storage_map
): ):
"""Create a thunk for each output of the `Linker`s `FunctionGraph`. from aesara.link.jax.dispatch import jax_funcify
This is differs from the other thunk-making function in that it only
produces thunks for the `FunctionGraph` output nodes.
Parameters
----------
compute_map: dict
The compute map dictionary.
storage_map: dict
The storage map dictionary.
Returns return jax_funcify(
------- fgraph, order, input_storage, output_storage, storage_map, **kwargs
thunks: list )
A tuple containing the thunks.
output_nodes: list and their
A tuple containing the output nodes.
""" def jit_compile(self, fn):
import jax import jax
from aesara.link.jax.dispatch import jax_funcify, jax_typify
output_nodes = [o.owner for o in self.fgraph.outputs]
# Create a JAX-compilable function from our `FunctionGraph`
jaxed_fgraph = jax_funcify(
self.fgraph,
input_storage=input_storage,
output_storage=output_storage,
storage_map=storage_map,
)
# I suppose we can consider `Constant`s to be "static" according to # I suppose we can consider `Constant`s to be "static" according to
# JAX. # JAX.
static_argnums = [ static_argnums = [
n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant) n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
] ]
return jax.jit(fn, static_argnums)
def create_thunk_inputs(self, storage_map):
from aesara.link.jax.dispatch import jax_typify
thunk_inputs = [] thunk_inputs = []
for n in self.fgraph.inputs: for n in self.fgraph.inputs:
...@@ -79,121 +43,4 @@ class JAXLinker(PerformLinker): ...@@ -79,121 +43,4 @@ class JAXLinker(PerformLinker):
sinput = [new_value] sinput = [new_value]
thunk_inputs.append(sinput) thunk_inputs.append(sinput)
thunks = [] return thunk_inputs
thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
fgraph_jit = jax.jit(jaxed_fgraph, static_argnums)
def thunk(
fgraph=self.fgraph,
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
for o_node, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
compute_map[o_node][0] = True
if len(o_storage) > 1:
assert len(o_storage) == len(o_val)
for i, o_sub_val in enumerate(o_val):
o_storage[i] = o_sub_val
else:
o_storage[0] = o_val
return outputs
thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs
thunk.lazy = False
thunks.append(thunk)
# This is a bit hackish, but we only return one of the output nodes
return thunks, output_nodes[:1]
def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph
nodes = self.schedule(fgraph)
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(
fgraph, nodes, input_storage, output_storage, storage_map
)
compute_map = {}
for k in storage_map:
compute_map[k] = [k.owner is None]
try:
# We need to create thunk functions that will populate the output
# storage arrays with the JAX-computed values.
thunks, nodes = self.create_jax_thunks(
compute_map, nodes, input_storage, output_storage, storage_map
)
except NotImplementedError as e:
if not self.allow_non_jax:
raise
warn(f"JaxLinker could not JAXify graph: {e}")
thunks = []
for node in nodes:
thunk = node.op.make_thunk(
node, storage_map, compute_map, no_recycling, "py"
)
thunk_inputs = [storage_map[v] for v in node.inputs]
thunk_outputs = [storage_map[v] for v in node.outputs]
thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs
thunks.append(thunk)
computed, last_user = gc_helper(nodes)
if self.allow_gc:
post_thunk_old_storage = []
for node in nodes:
post_thunk_old_storage.append(
[
storage_map[input]
for input in node.inputs
if (input in computed)
and (input not in fgraph.outputs)
and (node == last_user[input])
]
)
else:
post_thunk_old_storage = None
if no_recycling is True:
no_recycling = list(storage_map.values())
no_recycling = difference(no_recycling, input_storage)
else:
no_recycling = [
storage_map[r] for r in no_recycling if r not in fgraph.inputs
]
fn = streamline(
fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
)
fn.allow_gc = self.allow_gc
fn.storage_map = storage_map
return (
fn,
[
Container(input, storage)
for input, storage in zip(fgraph.inputs, input_storage)
],
[
Container(output, storage, readonly=True)
for output, storage in zip(fgraph.outputs, output_storage)
],
thunks,
nodes,
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论