提交 0e4bf69b authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Introduce a JAX Linker class

Closes #10. Well, at least the JAX part, but Cython and Numba implementations can follow very similar approaches and accomplish the same thing.
上级 b2a5bd9b
...@@ -60,6 +60,7 @@ install: ...@@ -60,6 +60,7 @@ install:
- conda create --yes -q -n pyenv python=$TRAVIS_PYTHON_VERSION - conda create --yes -q -n pyenv python=$TRAVIS_PYTHON_VERSION
- conda activate pyenv - conda activate pyenv
- conda install --yes -q mkl numpy scipy pip mkl-service graphviz cython libgpuarray pygpu - conda install --yes -q mkl numpy scipy pip mkl-service graphviz cython libgpuarray pygpu
- if [[ "$TRAVIS_PYTHON_VERSION" != "3.6" ]]; then conda install --yes -q -c conda-forge 'jax<0.2.0' 'jaxlib'; fi
- pip install -q -r requirements.txt - pip install -q -r requirements.txt
- conda list && pip freeze - conda list && pip freeze
- python -c 'import theano; print(theano.config.__str__(print_doc=False))' - python -c 'import theano; print(theano.config.__str__(print_doc=False))'
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
flake8 flake8
pep8 pep8
pyflakes pyflakes
black==20.8b1 black==20.8b1; platform.python_implementation!='PyPy'
pytest-cov>=2.6.1 pytest-cov>=2.6.1
coverage>=5.1 coverage>=5.1
pytest pytest
...@@ -10,3 +10,5 @@ coveralls ...@@ -10,3 +10,5 @@ coveralls
cython cython
sympy sympy
versioneer versioneer
jax<0.2.0; python_version > '3.6'
jaxlib; python_version > '3.6'
差异被折叠。
...@@ -7,12 +7,13 @@ import logging ...@@ -7,12 +7,13 @@ import logging
import warnings import warnings
import theano import theano
from theano import gof
import theano.gof.vm import theano.gof.vm
from theano import config
from six import string_types from six import string_types
from theano.compile.function_module import Supervisor
from theano import config, gof
from theano.compile.function_module import Supervisor
from theano.sandbox.jax_linker import JAXLinker
_logger = logging.getLogger("theano.compile.mode") _logger = logging.getLogger("theano.compile.mode")
...@@ -29,6 +30,7 @@ predefined_linkers = { ...@@ -29,6 +30,7 @@ predefined_linkers = {
"cvm": gof.vm.VM_Linker(use_cloop=True), # Use allow_gc Theano flag "cvm": gof.vm.VM_Linker(use_cloop=True), # Use allow_gc Theano flag
"vm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=False), "vm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=False),
"cvm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=True), "cvm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=True),
"jax": JAXLinker(),
} }
......
from warnings import warn
from collections.abc import Sequence
from theano.gof.link import (
PerformLinker,
map_storage,
gc_helper,
utils,
add_clear_storage,
Container,
streamline,
)
from theano.gof.graph import Constant
class JAXLinker(PerformLinker):
    """A `Linker` that JIT-compiles NumPy-based operations using JAX.

    Attributes
    ----------
    allow_non_jax: bool
        A boolean indicating whether or not an exception is thrown when the
        graph cannot be JAX compiled (e.g. the graph has an unsupported
        operator).  If `allow_non_jax` is `True`, the fallback is currently
        Python compilation.

    """

    allow_non_jax = False

    def create_jax_thunks(self, compute_map, storage_map):
        """Create a thunk for each output of the `Linker`'s `FunctionGraph`.

        This differs from the other thunk-making function in that it only
        produces thunks for the `FunctionGraph` output nodes.

        Parameters
        ----------
        compute_map: dict
            The compute map dictionary.
        storage_map: dict
            The storage map dictionary.

        Returns
        -------
        thunks: list
            A list containing one thunk per output node.
        output_nodes: list
            A list containing the output nodes, in the same order as
            `thunks`.

        """
        # Imported lazily so that this module can be imported without JAX.
        import jax

        from theano.sandbox.jaxify import jax_funcify

        output_nodes = [o.owner for o in self.fgraph.outputs]

        # Create a JAX-compilable function from our `FunctionGraph`
        jaxed_fgraph_outputs = jax_funcify(self.fgraph)

        assert len(jaxed_fgraph_outputs) == len(output_nodes)

        # I suppose we can consider `Constant`s to be "static" according to
        # JAX.
        static_argnums = tuple(
            n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
        )

        # Every thunk reads the same graph-level input storage cells.
        thunk_inputs = [storage_map[n] for n in self.fgraph.inputs]

        thunks = []

        for node, jax_funcs in zip(output_nodes, jaxed_fgraph_outputs):

            thunk_outputs = [storage_map[n] for n in node.outputs]

            # JIT-compile the functions.  `static_argnums` is passed by
            # keyword so the call does not depend on its position in the
            # `jax.jit` signature.
            if len(node.outputs) > 1:
                # Multi-output nodes come with one JAX function per output.
                assert len(jax_funcs) == len(node.outputs)
                jax_impl_jits = [
                    jax.jit(jax_func, static_argnums=static_argnums)
                    for jax_func in jax_funcs
                ]
            else:
                assert not isinstance(jax_funcs, Sequence)
                jax_impl_jits = [jax.jit(jax_funcs, static_argnums=static_argnums)]

            # The loop variables are bound as default arguments so that each
            # thunk keeps the values from its own iteration (closures would
            # otherwise all see the values from the last iteration).
            def thunk(
                node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs
            ):
                outputs = [
                    jax_impl_jit(*[x[0] for x in thunk_inputs])
                    for jax_impl_jit in jax_impl_jits
                ]

                for o_node, o_storage, o_val in zip(
                    node.outputs, thunk_outputs, outputs
                ):
                    compute_map[o_node][0] = True
                    if len(o_storage) > 1:
                        assert len(o_storage) == len(o_val)
                        for i, o_sub_val in enumerate(o_val):
                            o_storage[i] = o_sub_val
                    else:
                        o_storage[0] = o_val
                return outputs

            thunk.inputs = thunk_inputs
            thunk.outputs = thunk_outputs
            thunk.lazy = False

            thunks.append(thunk)

        return thunks, output_nodes

    def make_all(self, input_storage=None, output_storage=None, storage_map=None):
        """Build the callable and storage containers for `self.fgraph`.

        The graph is first handed to `create_jax_thunks`.  If that raises
        `NotImplementedError` and `allow_non_jax` is true, a warning is
        issued and one Python (``"py"``) thunk is created per scheduled node
        instead; otherwise the exception propagates.

        Parameters
        ----------
        input_storage: optional
            Forwarded to `map_storage`.
        output_storage: optional
            Forwarded to `map_storage`.
        storage_map: optional
            Forwarded to `map_storage`.

        Returns
        -------
        tuple
            ``(fn, input_containers, output_containers, thunks, nodes)``
            where `fn` is the result of `streamline`, the containers wrap the
            graph inputs/outputs with their storage, and `thunks`/`nodes` are
            the thunks that were created and the nodes they correspond to.

        Raises
        ------
        NotImplementedError
            If the graph cannot be JAX compiled and `allow_non_jax` is false.

        """
        fgraph = self.fgraph
        nodes = self.schedule(fgraph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            fgraph, nodes, input_storage, output_storage, storage_map
        )

        # Variables without an owner are marked as already computed.
        compute_map = {k: [k.owner is None] for k in storage_map}

        try:
            # We need to create thunk functions that will populate the output
            # storage arrays with the JAX-computed values.
            thunks, nodes = self.create_jax_thunks(compute_map, storage_map)
        except NotImplementedError as e:
            if not self.allow_non_jax:
                raise

            warn("JaxLinker could not JAXify graph: {}".format(e))

            # Fall back to one Python thunk per scheduled node; `nodes` still
            # holds the schedule because the assignment above did not happen.
            thunks = []
            for node in nodes:
                thunk = node.op.make_thunk(
                    node, storage_map, compute_map, no_recycling, "py"
                )
                thunk.inputs = [storage_map[v] for v in node.inputs]
                thunk.outputs = [storage_map[v] for v in node.outputs]
                thunks.append(thunk)

        computed, last_user = gc_helper(nodes)

        if self.allow_gc:
            # For each node, collect the storage of computed, non-output
            # inputs for which this node is the last user.
            post_thunk_old_storage = [
                [
                    storage_map[var]
                    for var in node.inputs
                    if (var in computed)
                    and (var not in fgraph.outputs)
                    and (node == last_user[var])
                ]
                for node in nodes
            ]
        else:
            post_thunk_old_storage = None

        if no_recycling is True:
            no_recycling = list(storage_map.values())
            no_recycling = utils.difference(no_recycling, input_storage)
        else:
            no_recycling = [
                storage_map[r] for r in no_recycling if r not in fgraph.inputs
            ]

        fn = streamline(
            fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
        )

        fn.allow_gc = self.allow_gc
        add_clear_storage(fn, computed, storage_map)
        fn.storage_map = storage_map

        return (
            fn,
            [
                Container(var, storage)
                for var, storage in zip(fgraph.inputs, input_storage)
            ],
            [
                Container(var, storage, True)
                for var, storage in zip(fgraph.outputs, output_storage)
            ],
            thunks,
            nodes,
        )
差异被折叠。
...@@ -2988,6 +2988,8 @@ class Inv(UnaryScalarOp): ...@@ -2988,6 +2988,8 @@ class Inv(UnaryScalarOp):
""" """
nfunc_spec = ("reciprocal", 1, 1)
def impl(self, x): def impl(self, x):
return np.float32(1.0) / x return np.float32(1.0) / x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论