Commit c9333bcf authored by Brandon T. Willard, committed by Brandon T. Willard

Create flat functions via AST for JAXification of FunctionGraphs

Parent 334c86fb
import ast
import re
import warnings
from collections.abc import Sequence
from functools import reduce, singledispatch, update_wrapper
from collections import Counter
from functools import reduce, singledispatch
from keyword import iskeyword
from tempfile import NamedTemporaryFile
from textwrap import indent
from types import FunctionType
from warnings import warn
import jax
......@@ -11,8 +17,10 @@ from numpy.random import RandomState
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.ifelse import IfElse
from aesara.link.utils import map_storage
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scan.op import Scan
from aesara.scan.utils import scan_args as ScanArgs
......@@ -95,102 +103,6 @@ subtensor_ops = (Subtensor, AdvancedSubtensor1, AdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
    """Compose JAX implementations of node operations.

    This function walks the graph given by the `Apply` node, `out_node`, and
    creates JAX JIT-able functions for its input and output variables.

    Parameters
    ----------
    out_node: aesara.graph.basic.Apply
        The node for which we want to construct a JAX JIT-able function.
    fgraph_inputs: List[Variable]
        The inputs--in a `FunctionGraph` sense--to `out_node`.
    memo: Mapping (Optional)
        A map from visited nodes to their JAX functions.  Used to avoid
        re-converting shared sub-graphs.

    Returns
    -------
    A `function` object that represents the composed JAX operations and takes
    the same form of inputs as `fgraph_inputs` (or a list of such functions
    when `out_node` has multiple outputs).

    """
    if memo is None:
        memo = {}

    if out_node in memo:
        return memo[out_node]

    jax_return_func = jax_funcify(out_node.op)

    # We create a list of JAX-able functions that produce the values of each
    # input variable for `out_node`.
    input_funcs = []
    for i in out_node.inputs:
        if i in fgraph_inputs:
            # This input is a top-level input (i.e. an input to the
            # `FunctionGraph` in which this `out_node` resides)
            idx = fgraph_inputs.index(i)

            i_dtype = getattr(i, "dtype", None)

            # `idx` and `i_dtype` are bound as defaults so the closure keeps
            # the values from *this* loop iteration.
            def jax_inputs_func(*inputs, i_dtype=i_dtype, idx=idx):
                return jax_typify(inputs[idx], i_dtype)

            input_f = jax_inputs_func

        elif i.owner is None:
            # This input is something like an `aesara.graph.basic.Constant`
            i_dtype = getattr(i, "dtype", None)
            i_data = i.data

            def jax_data_func(*inputs, i_dtype=i_dtype, i_data=i_data):
                return jax_typify(i_data, i_dtype)

            input_f = jax_data_func
        else:
            # This input is the output of another node, so we need to
            # generate a JAX-able function for its subgraph
            input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)

            if i.owner.nout > 1:
                # This input is one of multiple outputs from the `i.owner`
                # node, and we need to determine exactly which one it is and
                # create a JAX-able function that returns only it.
                out_idx = i.owner.outputs.index(i)
                (out_fn,) = input_f

                def jax_multiout_func(*inputs, out_idx=out_idx, out_fn=out_fn):
                    return out_fn(*inputs)[out_idx]

                input_f = jax_multiout_func

        assert callable(input_f)

        input_funcs.append(input_f)

    if not isinstance(jax_return_func, Sequence):
        jax_return_func = [jax_return_func]

    jax_funcs = []
    for return_func in jax_return_func:

        # FIX: bind `return_func` as a keyword-only default.  Without this,
        # every `jax_func` created in the loop closes over the *same* loop
        # variable and would end up calling the last `return_func` (the
        # classic late-binding-closure bug).
        def jax_func(*inputs, return_func=return_func):
            func_args = [fn(*inputs) for fn in input_funcs]
            return return_func(*func_args)

        jax_funcs.append(update_wrapper(jax_func, return_func))

    if len(out_node.outputs) == 1:
        jax_funcs = jax_funcs[0]

    memo[out_node] = jax_funcs

    return jax_funcs
@singledispatch
def jax_typify(data, dtype):
"""Convert instances of Aesara `Type`s to JAX types."""
......@@ -213,7 +125,7 @@ def jax_typify_RandomState(state, dtype):
@singledispatch
def jax_funcify(op, **kwargs):
    """Create a JAX compatible function from an Aesara `Op`.

    Implementations are registered per-`Op` type (and for `FunctionGraph`)
    via `jax_funcify.register`; this base case means "no conversion exists".
    """
    # FIX: removed the duplicated pre-change signature line
    # (`def jax_funcify(op):`) left over from the diff; only the
    # `**kwargs`-accepting signature is kept.
    raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
......@@ -458,8 +370,17 @@ def jax_funcify_Elemwise(op):
@jax_funcify.register(Composite)
def jax_funcify_Composite(op):
    """Create a JAX function for a scalar `Composite` `Op`.

    This approach basically gets rid of the fused `Elemwise` by turning each
    `Op` in the `Composite` back into individually broadcasted NumPy-like
    operations.

    TODO: A better approach would involve something like `jax.vmap` or some
    other operation that can perform the broadcasting that `Elemwise` does.
    """
    jax_impl = jax_funcify(op.fgraph)

    # FIX: removed the stray pre-change `return jax_impl` diff-residue line
    # that would have returned the tuple-producing graph function before the
    # unwrapping `composite` wrapper below was defined.
    def composite(*args):
        # `jax_funcify(FunctionGraph)` always returns a tuple of outputs; a
        # `Composite` has exactly one output, so unwrap it here.
        return jax_impl(*args)[0]

    return composite
@jax_funcify.register(Scan)
......@@ -684,12 +605,94 @@ def jax_funcify_AdvancedIncSubtensor(op):
@jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph(
    fgraph, order=None, input_storage=None, output_storage=None, storage_map=None
):
    """Create a JAX function for an entire `FunctionGraph`.

    A flat Python function is generated as source code (then parsed and
    compiled via `ast`/`compile`) that applies the JAX implementation of each
    `Op` in topological order, so shared sub-expressions are computed exactly
    once.

    Parameters
    ----------
    fgraph
        The `FunctionGraph` to convert.
    order
        Topological ordering of the graph's `Apply` nodes; computed via
        `fgraph.toposort()` when not given.
    input_storage, output_storage, storage_map
        Optional storage containers, forwarded to `map_storage`.

    Returns
    -------
    A function taking the graph's inputs (in order) and returning a tuple of
    the graph's outputs (a single output is wrapped in a 1-tuple).
    """
    if order is None:
        order = fgraph.toposort()

    input_storage, output_storage, storage_map = map_storage(
        fgraph, order, input_storage, output_storage, storage_map
    )

    # Globals for the generated function: per-`Op` JAX implementations and
    # constant values, keyed by their generated local names.
    global_env = {}
    fgraph_name = "jax_funcified_fgraph"

    def unique_name(x, names_counter=Counter([fgraph_name]), obj_to_names={}):
        # Return a stable, unique, valid Python identifier for `x`.  The
        # mutable defaults deliberately act as memo tables; they are fresh
        # per call of `jax_funcify_FunctionGraph` because `unique_name` is
        # re-defined each time.
        if x in obj_to_names:
            return obj_to_names[x]

        if isinstance(x, Variable):
            name = re.sub("[^0-9a-zA-Z]+", "_", x.name) if x.name else ""
            name = (
                name if (name.isidentifier() and not iskeyword(name)) else x.auto_name
            )
        elif isinstance(x, FunctionType):
            name = x.__name__
        else:
            name = type(x).__name__

        # Disambiguate repeated names with a numeric suffix.
        name_suffix = names_counter.get(name, "")
        local_name = f"{name}{name_suffix}"

        names_counter.update((name,))
        obj_to_names[x] = local_name

        return local_name

    body_assigns = []
    for node in order:
        jax_func = jax_funcify(node.op)

        # Create a local alias with a unique name
        local_jax_func_name = unique_name(jax_func)
        global_env[local_jax_func_name] = jax_func

        node_input_names = []
        for i in node.inputs:
            local_input_name = unique_name(i)
            if storage_map[i][0] is not None or isinstance(i, Constant):
                # Constants need to be assigned locally and referenced
                global_env[local_input_name] = jax_typify(storage_map[i][0], None)
                # TODO: We could attempt to use the storage arrays directly
                # E.g. `local_input_name = f"{local_input_name}[0]"`
            node_input_names.append(local_input_name)

        node_output_names = [unique_name(v) for v in node.outputs]

        body_assigns.append(
            f"{', '.join(node_output_names)} = {local_jax_func_name}({', '.join(node_input_names)})"
        )

    fgraph_input_names = [unique_name(v) for v in fgraph.inputs]
    fgraph_output_names = [unique_name(v) for v in fgraph.outputs]
    joined_body_assigns = indent("\n".join(body_assigns), "    ")

    # Always return a tuple, even for a single output, so callers can
    # uniformly iterate over the result.
    if len(fgraph_output_names) == 1:
        fgraph_return_src = f"({fgraph_output_names[0]},)"
    else:
        fgraph_return_src = ", ".join(fgraph_output_names)

    fgraph_def_src = f"""
def {fgraph_name}({", ".join(fgraph_input_names)}):
{joined_body_assigns}
    return {fgraph_return_src}
"""

    fgraph_def_ast = ast.parse(fgraph_def_src)

    # Create source code to be (at least temporarily) associated with the
    # compiled function (e.g. for easier debugging).  `delete=False` is
    # deliberate: the file must outlive this call so tracebacks and debuggers
    # can display the generated source.
    with NamedTemporaryFile(delete=False) as f:
        filename = f.name
        f.write(fgraph_def_src.encode())

    mod_code = compile(fgraph_def_ast, filename, mode="exec")

    # FIX: `exec` into an explicit namespace dict instead of `locals()`.
    # Writing through a function's `locals()` is implementation-defined, so
    # retrieving the compiled function from it was fragile; also removed the
    # pre-change residue lines (`compose_jax_funcs`-based body) left over
    # from the diff.
    local_env = {}
    exec(mod_code, global_env, local_env)

    fgraph_def = local_env[fgraph_name]

    return fgraph_def
@jax_funcify.register(CAReduce)
......
from collections.abc import Sequence
from warnings import warn
from numpy.random import RandomState
......@@ -23,7 +22,9 @@ class JAXLinker(PerformLinker):
allow_non_jax = False
# NOTE(review): the lines below are raw, unmarked diff residue for the
# `JAXLinker.create_jax_thunks` method: pre-change and post-change versions
# are interleaved, indentation has been stripped, and the `......@@` marker
# lines stand for omitted hunks — this span is not runnable as-is, so only
# annotations are added here.
def create_jax_thunks(self, compute_map, storage_map):
# NOTE(review): pre-change signature above; post-change signature below adds
# `order`, `input_storage`, and `output_storage`.
def create_jax_thunks(
self, compute_map, order, input_storage, output_storage, storage_map
):
"""Create a thunk for each output of the `Linker`s `FunctionGraph`.
This is differs from the other thunk-making function in that it only
......@@ -51,9 +52,12 @@ class JAXLinker(PerformLinker):
output_nodes = [o.owner for o in self.fgraph.outputs]
# Create a JAX-compilable function from our `FunctionGraph`
# NOTE(review): pre-change code converted the graph into one JAX function
# per output node (`jaxed_fgraph_outputs`)...
jaxed_fgraph_outputs = jax_funcify(self.fgraph)
assert len(jaxed_fgraph_outputs) == len(output_nodes)
# NOTE(review): ...while post-change code funcifies the whole graph once,
# forwarding the storage containers.
jaxed_fgraph = jax_funcify(
self.fgraph,
input_storage=input_storage,
output_storage=output_storage,
storage_map=storage_map,
)
# I suppose we can consider `Constant`s to be "static" according to
# JAX.
......@@ -75,52 +79,36 @@ class JAXLinker(PerformLinker):
# NOTE(review): the hunk omitted above presumably computed `thunk_inputs`
# and `static_argnums` — not visible here; confirm against the repository.
thunks = []
for node, jax_funcs in zip(output_nodes, jaxed_fgraph_outputs):
# NOTE(review): post-change gathers storage for *all* graph outputs; the
# line after it is the pre-change per-node variant.
thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
thunk_outputs = [storage_map[n] for n in node.outputs]
# NOTE(review): post-change JIT-compiles the single whole-graph function.
fgraph_jit = jax.jit(jaxed_fgraph, static_argnums)
if not isinstance(jax_funcs, Sequence):
jax_funcs = [jax_funcs]
# NOTE(review): post-change thunk — one call of the jitted whole-graph
# function populates every output's storage cell.
def thunk(
fgraph=self.fgraph,
fgraph_jit=fgraph_jit,
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
# NOTE(review): pre-change per-output JIT compilation (diff residue).
jax_impl_jits = [
jax.jit(jax_func, static_argnums) for jax_func in jax_funcs
]
for o_node, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
compute_map[o_node][0] = True
if len(o_storage) > 1:
assert len(o_storage) == len(o_val)
for i, o_sub_val in enumerate(o_val):
o_storage[i] = o_sub_val
else:
o_storage[0] = o_val
return outputs
thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs
thunk.lazy = False
thunks.append(thunk)
# NOTE(review): everything from here to the pre-change `return` is the old
# per-node thunk implementation that the commit removes.
def thunk(
node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs
):
outputs = [
jax_impl_jit(*[x[0] for x in thunk_inputs])
for jax_impl_jit in jax_impl_jits
]
if len(jax_impl_jits) < len(node.outputs):
# In this case, the JAX function will output a single
# output that contains the other outputs.
# This happens for multi-output `Op`s that directly
# correspond to multi-output JAX functions (e.g. `SVD` and
# `jax.numpy.linalg.svd`).
outputs = outputs[0]
for o_node, o_storage, o_val in zip(
node.outputs, thunk_outputs, outputs
):
compute_map[o_node][0] = True
if len(o_storage) > 1:
assert len(o_storage) == len(o_val)
for i, o_sub_val in enumerate(o_val):
o_storage[i] = o_sub_val
else:
o_storage[0] = o_val
return outputs
thunk.inputs = thunk_inputs
thunk.outputs = thunk_outputs
thunk.lazy = False
thunks.append(thunk)
# NOTE(review): pre-change return (all output nodes) vs. post-change return
# below, which deliberately keeps only the first node.
return thunks, output_nodes
# This is a bit hackish, but we only return one of the output nodes
return thunks, output_nodes[:1]
# NOTE(review): truncated diff residue for `JAXLinker.make_all` — interior
# hunks are omitted (`......@@` marker) and the method continues past this
# excerpt, so only annotations are added.
def make_all(self, input_storage=None, output_storage=None, storage_map=None):
fgraph = self.fgraph
......@@ -138,7 +126,9 @@ class JAXLinker(PerformLinker):
try:
# We need to create thunk functions that will populate the output
# storage arrays with the JAX-computed values.
# NOTE(review): pre-change call below; the post-change call that follows
# additionally forwards `nodes`, `input_storage`, and `output_storage` to
# match the new `create_jax_thunks` signature.
thunks, nodes = self.create_jax_thunks(compute_map, storage_map)
thunks, nodes = self.create_jax_thunks(
compute_map, nodes, input_storage, output_storage, storage_map
)
except NotImplementedError as e:
if not self.allow_non_jax:
......
......@@ -10,11 +10,13 @@ from aesara.compile.mode import Mode
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.graph.op import Op, get_test_value
from aesara.graph.optdb import Query
from aesara.ifelse import ifelse
from aesara.link.jax import JAXLinker
from aesara.scalar.basic import Composite
from aesara.scan.basic import scan
from aesara.tensor import basic as aet
from aesara.tensor import blas as aet_blas
......@@ -24,6 +26,7 @@ from aesara.tensor import nlinalg as aet_nlinalg
from aesara.tensor import nnet as aet_nnet
from aesara.tensor import slinalg as aet_slinalg
from aesara.tensor import subtensor as aet_subtensor
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import all as aet_all
from aesara.tensor.math import clip, cosh, gammaln, log
......@@ -295,6 +298,94 @@ def test_jax_basic():
)
def test_jax_Composite():
    """Check the JAX translation of an `Elemwise`-wrapped scalar `Composite`."""
    x_s = aes.float64("x")
    y_s = aes.float64("y")

    # Fused scalar graph computing ``x + y * 2``, lifted to vectors.
    composite_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2]))

    x = vector("x")
    y = vector("y")
    out_fg = FunctionGraph([x, y], [composite_op(x, y)])

    input_values = [
        np.arange(10).astype(config.floatX),
        np.arange(10, 20).astype(config.floatX),
    ]
    _ = compare_jax_and_py(out_fg, input_values)
def test_jax_FunctionGraph_names():
    """Invalid, missing, or keyword variable names must be replaced by each
    variable's ``auto_name`` in the generated function's signature."""
    import inspect

    from aesara.link.jax.jax_dispatch import jax_funcify

    # "1x" is not an identifier, "_" is valid, ``None`` means unnamed,
    # and "def" is a Python keyword.
    x = scalar("1x")
    y = scalar("_")
    z = scalar()
    q = scalar("def")

    out_fg = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False)
    out_jx = jax_funcify(out_fg)

    param_names = tuple(inspect.signature(out_jx).parameters)
    assert param_names == (x.auto_name, "_", z.auto_name, q.auto_name)

    # The identity graph should pass values straight through.
    assert out_jx(1, 2, 3, 4) == (1, 2, 3, 4)
def test_jax_FunctionGraph_once():
    """Make sure that an output is only computed once when it's referenced multiple times."""
    from aesara.link.jax.jax_dispatch import jax_funcify

    x = vector("x")
    y = vector("y")

    class TestOp(Op):
        # Counts how many times the JAX implementation actually runs.
        def __init__(self):
            self.called = 0

        def make_node(self, *args):
            return Apply(self, list(args), [x.type() for x in args])

        def perform(self, inputs, outputs):
            for inp, out in zip(inputs, outputs):
                out[0] = inp[0]

    @jax_funcify.register(TestOp)
    def jax_funcify_TestOp(op):
        def func(*args, op=op):
            # Bump the counter on every runtime invocation.
            op.called += 1
            return list(args)

        return func

    op1 = TestOp()
    op2 = TestOp()

    q, r = op1(x, y)
    outs = op2(q + r, q + r)
    out_fg = FunctionGraph([x, y], outs, clone=False)
    assert len(out_fg.outputs) == 2

    out_jx = jax_funcify(out_fg)

    x_val = np.r_[1, 2].astype(config.floatX)
    y_val = np.r_[2, 3].astype(config.floatX)

    # Each evaluation should run every op exactly once, even though both
    # ops' outputs are referenced twice in the graph.
    for expected_calls in (1, 2):
        res = out_jx(x_val, y_val)
        assert len(res) == 2
        assert op1.called == expected_calls
        assert op2.called == expected_calls
def test_jax_eye():
"""Tests jaxification of the Eye operator"""
out = aet.eye(3)
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment