Commit c9333bcf, authored and committed by Brandon T. Willard

Create flat functions via AST for JAXification of FunctionGraphs

Parent: 334c86fb
import ast
import re
import warnings import warnings
from collections.abc import Sequence from collections import Counter
from functools import reduce, singledispatch, update_wrapper from functools import reduce, singledispatch
from keyword import iskeyword
from tempfile import NamedTemporaryFile
from textwrap import indent
from types import FunctionType
from warnings import warn from warnings import warn
import jax import jax
...@@ -11,8 +17,10 @@ from numpy.random import RandomState ...@@ -11,8 +17,10 @@ from numpy.random import RandomState
from aesara.compile.ops import DeepCopyOp, ViewOp from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.ifelse import IfElse from aesara.ifelse import IfElse
from aesara.link.utils import map_storage
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.scan.utils import scan_args as ScanArgs from aesara.scan.utils import scan_args as ScanArgs
...@@ -95,102 +103,6 @@ subtensor_ops = (Subtensor, AdvancedSubtensor1, AdvancedSubtensor) ...@@ -95,102 +103,6 @@ subtensor_ops = (Subtensor, AdvancedSubtensor1, AdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1) incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
    """Compose JAX implementations of node operations.

    This function walks the graph given by the `Apply` node, `out_node`, and
    creates JAX JIT-able functions for its input and output variables.

    Parameters
    ----------
    out_node: aesara.graph.basic.Apply
        The node for which we want to construct a JAX JIT-able function.
    fgraph_inputs: List[Variable]
        The inputs--in a `FunctionGraph` sense--to `out_node`.
    memo: Mapping (Optional)
        A map from visited nodes to their JAX functions.

    Outputs
    -------
    A `function` object that represents the composed JAX operations and takes
    the same form of inputs as `fgraph_inputs`.

    """
    if memo is None:
        memo = {}

    if out_node in memo:
        # This subgraph has already been converted; reuse the cached result so
        # shared subexpressions are not funcified (or later re-computed) twice.
        return memo[out_node]

    jax_return_func = jax_funcify(out_node.op)

    # We create a list of JAX-able functions that produce the values of each
    # input variable for `out_node`.
    input_funcs = []
    for i in out_node.inputs:
        if i in fgraph_inputs:
            # This input is a top-level input (i.e. an input to the
            # `FunctionGraph` in which this `out_node` resides)
            idx = fgraph_inputs.index(i)

            i_dtype = getattr(i, "dtype", None)

            # `idx`/`i_dtype` are bound as defaults so each closure keeps the
            # values from this iteration of the loop.
            def jax_inputs_func(*inputs, i_dtype=i_dtype, idx=idx):
                return jax_typify(inputs[idx], i_dtype)

            input_f = jax_inputs_func
        elif i.owner is None:
            # This input is something like an `aesara.graph.basic.Constant`
            i_dtype = getattr(i, "dtype", None)
            i_data = i.data

            def jax_data_func(*inputs, i_dtype=i_dtype, i_data=i_data):
                return jax_typify(i_data, i_dtype)

            input_f = jax_data_func
        else:
            # This input is the output of another node, so we need to
            # generate a JAX-able function for its subgraph
            input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)

            if i.owner.nout > 1:
                # This input is one of multiple outputs from the `i.owner`
                # node, and we need to determine exactly which one it is and
                # create a JAX-able function that returns only it.
                out_idx = i.owner.outputs.index(i)

                # NOTE(review): this assumes the recursive call produced a
                # single function returning a tuple of outputs — confirm for
                # multi-output `Op`s that map to several JAX functions.
                (out_fn,) = input_f

                def jax_multiout_func(*inputs, out_idx=out_idx, out_fn=out_fn):
                    return out_fn(*inputs)[out_idx]

                input_f = jax_multiout_func

        assert callable(input_f)

        input_funcs.append(input_f)

    if not isinstance(jax_return_func, Sequence):
        jax_return_func = [jax_return_func]

    jax_funcs = []
    for return_func in jax_return_func:

        # FIX: bind `return_func` as a keyword default.  The previous version
        # closed over the loop variable, so when `jax_return_func` had more
        # than one element every composed function late-bound to (and thus
        # called) the *last* `return_func` in the list.
        def jax_func(*inputs, return_func=return_func):
            func_args = [fn(*inputs) for fn in input_funcs]
            return return_func(*func_args)

        jax_funcs.append(update_wrapper(jax_func, return_func))

    if len(out_node.outputs) == 1:
        # Single-output nodes memoize/return the bare function, not a list.
        jax_funcs = jax_funcs[0]

    memo[out_node] = jax_funcs

    return jax_funcs
@singledispatch @singledispatch
def jax_typify(data, dtype): def jax_typify(data, dtype):
"""Convert instances of Aesara `Type`s to JAX types.""" """Convert instances of Aesara `Type`s to JAX types."""
...@@ -213,7 +125,7 @@ def jax_typify_RandomState(state, dtype): ...@@ -213,7 +125,7 @@ def jax_typify_RandomState(state, dtype):
@singledispatch
def jax_funcify(op, **kwargs):
    """Create a JAX compatible function from an Aesara `Op`.

    This is the generic fallback; concrete `Op` types register their own
    implementations via `jax_funcify.register`.  `**kwargs` lets registered
    implementations accept extra options (e.g. storage maps) without
    breaking the dispatch interface.

    Raises
    ------
    NotImplementedError
        If no JAX conversion has been registered for ``type(op)``.

    """
    raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
...@@ -458,8 +370,17 @@ def jax_funcify_Elemwise(op): ...@@ -458,8 +370,17 @@ def jax_funcify_Elemwise(op):
@jax_funcify.register(Composite)
def jax_funcify_Composite(op):
    """Convert a fused scalar `Composite` into a JAX-callable function.

    This approach basically gets rid of the fused `Elemwise` by turning each
    `Op` in the `Composite` back into individually broadcasted NumPy-like
    operations.

    TODO: A better approach would involve something like `jax.vmap` or some
    other operation that can perform the broadcasting that `Elemwise` does.
    """
    # Funcify the Composite's inner graph of scalar operations.
    inner_fn = jax_funcify(op.fgraph)

    def composite(*args):
        # The funcified `FunctionGraph` always returns a tuple; take the
        # first element, as the original implementation did.
        results = inner_fn(*args)
        return results[0]

    return composite
@jax_funcify.register(Scan) @jax_funcify.register(Scan)
...@@ -684,12 +605,94 @@ def jax_funcify_AdvancedIncSubtensor(op): ...@@ -684,12 +605,94 @@ def jax_funcify_AdvancedIncSubtensor(op):
@jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph(
    fgraph, order=None, input_storage=None, output_storage=None, storage_map=None
):
    """Create a single "flat" JAX function for an entire `FunctionGraph`.

    The graph is converted by generating Python source for one function whose
    body applies each node's JAX implementation in topological order; that
    source is parsed, compiled, and `exec`-ed, and the resulting function is
    returned.

    Parameters
    ----------
    fgraph: FunctionGraph
        The graph to convert.
    order: list of Apply (Optional)
        A topological ordering of the graph's nodes; computed if not given.
    input_storage, output_storage, storage_map: (Optional)
        Storage containers in the `aesara.link.utils.map_storage` sense;
        created/completed by `map_storage` if not given.

    Returns
    -------
    A function taking one argument per `fgraph` input and returning a tuple
    with one entry per `fgraph` output.

    """
    if order is None:
        order = fgraph.toposort()

    input_storage, output_storage, storage_map = map_storage(
        fgraph, order, input_storage, output_storage, storage_map
    )

    # Globals for the generated function: node implementations and constants.
    global_env = {}
    fgraph_name = "jax_funcified_fgraph"

    def unique_name(x, names_counter=Counter([fgraph_name]), obj_to_names={}):
        # Return a stable, unique, valid Python identifier for `x`.
        # NOTE: the mutable defaults are intentional here — `unique_name` is
        # redefined per `jax_funcify_FunctionGraph` call, so the counter and
        # memo act as per-conversion state shared across calls.
        if x in obj_to_names:
            return obj_to_names[x]

        if isinstance(x, Variable):
            # Sanitize the user-given name; fall back to `auto_name` when the
            # result is not a usable identifier (empty, keyword, etc.).
            name = re.sub("[^0-9a-zA-Z]+", "_", x.name) if x.name else ""
            name = (
                name if (name.isidentifier() and not iskeyword(name)) else x.auto_name
            )
        elif isinstance(x, FunctionType):
            name = x.__name__
        else:
            name = type(x).__name__

        # Disambiguate repeated names with a running count suffix.
        name_suffix = names_counter.get(name, "")
        local_name = f"{name}{name_suffix}"
        names_counter.update((name,))
        obj_to_names[x] = local_name

        return local_name

    body_assigns = []
    for node in order:
        jax_func = jax_funcify(node.op)

        # Create a local alias with a unique name
        local_jax_func_name = unique_name(jax_func)
        global_env[local_jax_func_name] = jax_func

        node_input_names = []
        for i in node.inputs:
            local_input_name = unique_name(i)
            if storage_map[i][0] is not None or isinstance(i, Constant):
                # Constants need to be assigned locally and referenced
                global_env[local_input_name] = jax_typify(storage_map[i][0], None)
                # TODO: We could attempt to use the storage arrays directly
                # E.g. `local_input_name = f"{local_input_name}[0]"`
            node_input_names.append(local_input_name)

        node_output_names = [unique_name(v) for v in node.outputs]

        # One generated line per node: `out1, out2 = func(in1, in2, ...)`
        body_assigns.append(
            f"{', '.join(node_output_names)} = {local_jax_func_name}({', '.join(node_input_names)})"
        )

    fgraph_input_names = [unique_name(v) for v in fgraph.inputs]
    fgraph_output_names = [unique_name(v) for v in fgraph.outputs]
    joined_body_assigns = indent("\n".join(body_assigns), "    ")

    if len(fgraph_output_names) == 1:
        # Always return a tuple, even for a single output.
        fgraph_return_src = f"({fgraph_output_names[0]},)"
    else:
        fgraph_return_src = ", ".join(fgraph_output_names)

    # NOTE(review): the body/return indentation here must match the width
    # passed to `indent` above — confirm against the original source, since
    # whitespace was not preserved faithfully in this copy.
    fgraph_def_src = f"""
def {fgraph_name}({", ".join(fgraph_input_names)}):
{joined_body_assigns}
    return {fgraph_return_src}
"""

    fgraph_def_ast = ast.parse(fgraph_def_src)

    # Create source code to be (at least temporarily) associated with the
    # compiled function (e.g. for easier debugging)
    with NamedTemporaryFile(delete=False) as f:
        filename = f.name
        f.write(fgraph_def_src.encode())

    mod_code = compile(fgraph_def_ast, filename, mode="exec")
    # NOTE(review): relies on CPython's behavior that `exec` writes into the
    # dict returned by `locals()` and that a later `locals()` call still
    # exposes the added key — confirm; this is implementation-specific.
    exec(mod_code, global_env, locals())

    fgraph_def = locals()[fgraph_name]

    return fgraph_def
@jax_funcify.register(CAReduce) @jax_funcify.register(CAReduce)
......
from collections.abc import Sequence
from warnings import warn from warnings import warn
from numpy.random import RandomState from numpy.random import RandomState
...@@ -23,7 +22,9 @@ class JAXLinker(PerformLinker): ...@@ -23,7 +22,9 @@ class JAXLinker(PerformLinker):
allow_non_jax = False allow_non_jax = False
def create_jax_thunks(self, compute_map, storage_map): def create_jax_thunks(
self, compute_map, order, input_storage, output_storage, storage_map
):
"""Create a thunk for each output of the `Linker`s `FunctionGraph`. """Create a thunk for each output of the `Linker`s `FunctionGraph`.
This is differs from the other thunk-making function in that it only This is differs from the other thunk-making function in that it only
...@@ -51,9 +52,12 @@ class JAXLinker(PerformLinker): ...@@ -51,9 +52,12 @@ class JAXLinker(PerformLinker):
output_nodes = [o.owner for o in self.fgraph.outputs] output_nodes = [o.owner for o in self.fgraph.outputs]
# Create a JAX-compilable function from our `FunctionGraph` # Create a JAX-compilable function from our `FunctionGraph`
jaxed_fgraph_outputs = jax_funcify(self.fgraph) jaxed_fgraph = jax_funcify(
self.fgraph,
assert len(jaxed_fgraph_outputs) == len(output_nodes) input_storage=input_storage,
output_storage=output_storage,
storage_map=storage_map,
)
# I suppose we can consider `Constant`s to be "static" according to # I suppose we can consider `Constant`s to be "static" according to
# JAX. # JAX.
...@@ -75,36 +79,19 @@ class JAXLinker(PerformLinker): ...@@ -75,36 +79,19 @@ class JAXLinker(PerformLinker):
thunks = [] thunks = []
for node, jax_funcs in zip(output_nodes, jaxed_fgraph_outputs): thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
thunk_outputs = [storage_map[n] for n in node.outputs]
if not isinstance(jax_funcs, Sequence):
jax_funcs = [jax_funcs]
jax_impl_jits = [ fgraph_jit = jax.jit(jaxed_fgraph, static_argnums)
jax.jit(jax_func, static_argnums) for jax_func in jax_funcs
]
def thunk( def thunk(
node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs fgraph=self.fgraph,
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
): ):
outputs = [ outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
jax_impl_jit(*[x[0] for x in thunk_inputs])
for jax_impl_jit in jax_impl_jits
]
if len(jax_impl_jits) < len(node.outputs):
# In this case, the JAX function will output a single
# output that contains the other outputs.
# This happens for multi-output `Op`s that directly
# correspond to multi-output JAX functions (e.g. `SVD` and
# `jax.numpy.linalg.svd`).
outputs = outputs[0]
for o_node, o_storage, o_val in zip( for o_node, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
node.outputs, thunk_outputs, outputs
):
compute_map[o_node][0] = True compute_map[o_node][0] = True
if len(o_storage) > 1: if len(o_storage) > 1:
assert len(o_storage) == len(o_val) assert len(o_storage) == len(o_val)
...@@ -120,7 +107,8 @@ class JAXLinker(PerformLinker): ...@@ -120,7 +107,8 @@ class JAXLinker(PerformLinker):
thunks.append(thunk) thunks.append(thunk)
return thunks, output_nodes # This is a bit hackish, but we only return one of the output nodes
return thunks, output_nodes[:1]
def make_all(self, input_storage=None, output_storage=None, storage_map=None): def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph fgraph = self.fgraph
...@@ -138,7 +126,9 @@ class JAXLinker(PerformLinker): ...@@ -138,7 +126,9 @@ class JAXLinker(PerformLinker):
try: try:
# We need to create thunk functions that will populate the output # We need to create thunk functions that will populate the output
# storage arrays with the JAX-computed values. # storage arrays with the JAX-computed values.
thunks, nodes = self.create_jax_thunks(compute_map, storage_map) thunks, nodes = self.create_jax_thunks(
compute_map, nodes, input_storage, output_storage, storage_map
)
except NotImplementedError as e: except NotImplementedError as e:
if not self.allow_non_jax: if not self.allow_non_jax:
......
...@@ -10,11 +10,13 @@ from aesara.compile.mode import Mode ...@@ -10,11 +10,13 @@ from aesara.compile.mode import Mode
from aesara.compile.ops import DeepCopyOp, ViewOp from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.compile.sharedvalue import SharedVariable, shared from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value from aesara.graph.op import Op, get_test_value
from aesara.graph.optdb import Query from aesara.graph.optdb import Query
from aesara.ifelse import ifelse from aesara.ifelse import ifelse
from aesara.link.jax import JAXLinker from aesara.link.jax import JAXLinker
from aesara.scalar.basic import Composite
from aesara.scan.basic import scan from aesara.scan.basic import scan
from aesara.tensor import basic as aet from aesara.tensor import basic as aet
from aesara.tensor import blas as aet_blas from aesara.tensor import blas as aet_blas
...@@ -24,6 +26,7 @@ from aesara.tensor import nlinalg as aet_nlinalg ...@@ -24,6 +26,7 @@ from aesara.tensor import nlinalg as aet_nlinalg
from aesara.tensor import nnet as aet_nnet from aesara.tensor import nnet as aet_nnet
from aesara.tensor import slinalg as aet_slinalg from aesara.tensor import slinalg as aet_slinalg
from aesara.tensor import subtensor as aet_subtensor from aesara.tensor import subtensor as aet_subtensor
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import MaxAndArgmax from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import all as aet_all from aesara.tensor.math import all as aet_all
from aesara.tensor.math import clip, cosh, gammaln, log from aesara.tensor.math import clip, cosh, gammaln, log
...@@ -295,6 +298,94 @@ def test_jax_basic(): ...@@ -295,6 +298,94 @@ def test_jax_basic():
) )
def test_jax_Composite():
    """A fused scalar `Composite` wrapped in `Elemwise` should JAX-compile."""
    a_s = aes.float64("x")
    b_s = aes.float64("y")
    # Scalar graph computing x + y * 2, fused into a single Op.
    comp_op = Elemwise(Composite([a_s, b_s], [a_s + b_s * 2]))

    x = vector("x")
    y = vector("y")
    out_fg = FunctionGraph([x, y], [comp_op(x, y)])

    test_input_vals = [
        np.arange(10).astype(config.floatX),
        np.arange(10, 20).astype(config.floatX),
    ]
    _ = compare_jax_and_py(out_fg, test_input_vals)
def test_jax_FunctionGraph_names():
    """Invalid variable names must fall back to ``auto_name`` parameters."""
    import inspect

    from aesara.link.jax.jax_dispatch import jax_funcify

    # "1x" and "def" are not valid identifiers and the unnamed scalar has no
    # name at all; only "_" can be used verbatim.
    x = scalar("1x")
    y = scalar("_")
    z = scalar()
    q = scalar("def")

    fgraph = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False)
    compiled_fn = jax_funcify(fgraph)

    param_names = tuple(inspect.signature(compiled_fn).parameters)
    assert param_names == (x.auto_name, "_", z.auto_name, q.auto_name)
    assert compiled_fn(1, 2, 3, 4) == (1, 2, 3, 4)
def test_jax_FunctionGraph_once():
    """Make sure that an output is only computed once when it's referenced multiple times."""
    from aesara.link.jax.jax_dispatch import jax_funcify

    x = vector("x")
    y = vector("y")

    class TestOp(Op):
        # `called` counts invocations of the JAX-converted function, so the
        # assertions below can detect duplicated computation.
        def __init__(self):
            self.called = 0

        def make_node(self, *args):
            return Apply(self, list(args), [x.type() for x in args])

        # NOTE(review): `Op.perform` conventionally takes `(node, inputs,
        # output_storage)`; confirm this signature is intentional — it is
        # never exercised here since only the JAX path runs.
        def perform(self, inputs, outputs):
            for i, inp in enumerate(inputs):
                outputs[i][0] = inp[0]

    @jax_funcify.register(TestOp)
    def jax_funcify_TestOp(op):
        # Bind `op` as a default so the counter belongs to this instance.
        def func(*args, op=op):
            op.called += 1
            return list(args)

        return func

    op1 = TestOp()
    op2 = TestOp()

    # `q + r` is used twice as an input to `op2`; a correct conversion
    # computes it (and each op) only once per evaluation.
    q, r = op1(x, y)
    outs = op2(q + r, q + r)
    out_fg = FunctionGraph([x, y], outs, clone=False)
    assert len(out_fg.outputs) == 2

    out_jx = jax_funcify(out_fg)

    x_val = np.r_[1, 2].astype(config.floatX)
    y_val = np.r_[2, 3].astype(config.floatX)

    res = out_jx(x_val, y_val)
    assert len(res) == 2
    assert op1.called == 1
    assert op2.called == 1

    # A second evaluation calls each op exactly once more.
    res = out_jx(x_val, y_val)
    assert len(res) == 2
    assert op1.called == 2
    assert op2.called == 2
def test_jax_eye(): def test_jax_eye():
"""Tests jaxification of the Eye operator""" """Tests jaxification of the Eye operator"""
out = aet.eye(3) out = aet.eye(3)
......
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to comment