提交 0e4bf69b authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Introduce a JAX Linker class

Closes #10. Well, at least the JAX part, but Cython and Numba implementations can follow very similar approaches and accomplish the same thing.
上级 b2a5bd9b
......@@ -60,6 +60,7 @@ install:
- conda create --yes -q -n pyenv python=$TRAVIS_PYTHON_VERSION
- conda activate pyenv
- conda install --yes -q mkl numpy scipy pip mkl-service graphviz cython libgpuarray pygpu
- if [[ "$TRAVIS_PYTHON_VERSION" != "3.6" ]]; then conda install --yes -q -c conda-forge 'jax<0.2.0' 'jaxlib'; fi
- pip install -q -r requirements.txt
- conda list && pip freeze
- python -c 'import theano; print(theano.config.__str__(print_doc=False))'
......
......@@ -2,7 +2,7 @@
flake8
pep8
pyflakes
black==20.8b1
black==20.8b1; platform_python_implementation != 'PyPy'
pytest-cov>=2.6.1
coverage>=5.1
pytest
......@@ -10,3 +10,5 @@ coveralls
cython
sympy
versioneer
jax<0.2.0; python_version > '3.6'
jaxlib; python_version > '3.6'
import pytest
import numpy as np
import theano
import theano.tensor as tt
jax = pytest.importorskip("jax")
from theano.gof.op import get_test_value # noqa: E402
@pytest.fixture(scope="module", autouse=True)
def set_theano_flags():
    # Run every test in this module without the C compiler (cxx="") and with
    # test values computed eagerly ("warn"), so graph-construction problems
    # surface early without requiring a C toolchain.
    with theano.change_flags(cxx="", compute_test_value="warn"):
        yield
def compare_jax_and_py(fgraph, inputs, cmp_fn=np.allclose):
    """Compile `fgraph` with both the JAX and the Python linkers, evaluate
    each on `inputs`, and assert that the two results agree.

    Parameters
    ----------
    fgraph: theano.gof.FunctionGraph
        The graph to compile and evaluate.
    inputs: sequence
        Concrete input values, ordered like `fgraph.inputs`.
    cmp_fn: callable
        Predicate used to compare the JAX and Python results.

    Returns
    -------
    The value(s) produced by the JAX-compiled function.
    """
    jax_fn = theano.function(
        fgraph.inputs, fgraph.outputs, mode=theano.compile.Mode(linker="jax")
    )
    jax_res = jax_fn(*inputs)

    # Everything produced by the JAX linker should be a JAX device array.
    for res in jax_res if isinstance(jax_res, list) else [jax_res]:
        assert isinstance(res, jax.interpreters.xla.DeviceArray)

    py_fn = theano.function(
        fgraph.inputs, fgraph.outputs, mode=theano.compile.Mode(linker="py")
    )
    py_res = py_fn(*inputs)

    assert cmp_fn(jax_res, py_res)

    return jax_res
def test_jax_Alloc():
    """Test JAX conversions of `Alloc`/`AllocEmpty` for constant, scalar,
    and vector fill values."""
    x = tt.alloc(0.0, 2, 3)
    x_fg = theano.gof.FunctionGraph([], [x])

    (jax_res,) = compare_jax_and_py(x_fg, [])

    assert jax_res.shape == (2, 3)

    x = tt.alloc(1.1, 2, 3)
    x_fg = theano.gof.FunctionGraph([], [x])

    compare_jax_and_py(x_fg, [])

    x = theano.tensor.basic.AllocEmpty("float32")(2, 3)
    x_fg = theano.gof.FunctionGraph([], [x])

    def compare_shape_dtype(x, y):
        # `AllocEmpty` contents are uninitialized, so only shape and dtype
        # can be compared.
        (x,) = x
        (y,) = y
        return x.shape == y.shape and x.dtype == y.dtype

    (jax_res,) = compare_jax_and_py(x_fg, [], cmp_fn=compare_shape_dtype)

    a = tt.scalar("a")
    x = tt.alloc(a, 20)
    x_fg = theano.gof.FunctionGraph([a], [x])

    (jax_res,) = compare_jax_and_py(x_fg, [10.0])

    a = tt.vector("a")
    x = tt.alloc(a, 20, 10)
    x_fg = theano.gof.FunctionGraph([a], [x])

    (jax_res,) = compare_jax_and_py(x_fg, [np.ones(10, dtype=tt.config.floatX)])
def test_jax_compile_ops():
    """Test JAX conversions of the `theano.compile.ops` `Op`s: `DeepCopyOp`,
    `Shape`, `Shape_i`, `SpecifyShape`, `Rebroadcast`, and `ViewOp`."""
    x = theano.compile.ops.DeepCopyOp()(tt.as_tensor_variable(1.1))
    x_fg = theano.gof.FunctionGraph([], [x])

    (jax_res,) = compare_jax_and_py(x_fg, [])

    x_np = np.zeros((20, 3))
    x = theano.compile.ops.Shape()(tt.as_tensor_variable(x_np))
    x_fg = theano.gof.FunctionGraph([], [x])

    (jax_res,) = compare_jax_and_py(x_fg, [])

    x = theano.compile.ops.Shape_i(1)(tt.as_tensor_variable(x_np))
    x_fg = theano.gof.FunctionGraph([], [x])

    (jax_res,) = compare_jax_and_py(x_fg, [])

    x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (20, 3))
    x_fg = theano.gof.FunctionGraph([], [x])

    (jax_res,) = compare_jax_and_py(x_fg, [])

    # A mismatched `SpecifyShape` must fail; test values are disabled so the
    # error surfaces from the compiled function, not graph construction.
    with theano.change_flags(compute_test_value="off"):
        x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (2, 3))
        x_fg = theano.gof.FunctionGraph([], [x])

        with pytest.raises(AssertionError):
            (jax_res,) = compare_jax_and_py(x_fg, [])

    x_np = np.zeros((20, 1, 1))
    x = theano.compile.ops.Rebroadcast((0, False), (1, True), (2, False))(
        tt.as_tensor_variable(x_np)
    )
    x_fg = theano.gof.FunctionGraph([], [x])

    (jax_res,) = compare_jax_and_py(x_fg, [])

    # Rebroadcasting a dimension whose length is not 1 must raise.
    with theano.change_flags(compute_test_value="off"):
        x = theano.compile.ops.Rebroadcast((0, True), (1, False), (2, False))(
            tt.as_tensor_variable(x_np)
        )
        x_fg = theano.gof.FunctionGraph([], [x])

        with pytest.raises(ValueError):
            (jax_res,) = compare_jax_and_py(x_fg, [])

    x = theano.compile.ops.ViewOp()(tt.as_tensor_variable(x_np))
    x_fg = theano.gof.FunctionGraph([], [x])

    (jax_res,) = compare_jax_and_py(x_fg, [])
def test_jax_basic():
    """Test a small composite graph: a `ScalarOp` element-wise expression
    followed by `IncSubtensor`/`Subtensor` updates and slicing."""
    x = tt.matrix("x")
    y = tt.matrix("y")

    # `ScalarOp`
    z = tt.cosh(x ** 2 + y / 3.0)

    # `[Inc]Subtensor`
    out = tt.set_subtensor(z[0], -10.0)
    out = tt.inc_subtensor(out[0, 1], 2.0)
    out = out[:5, :3]

    out_fg = theano.gof.FunctionGraph([x, y], [out])

    test_input_vals = [
        np.tile(np.arange(10), (10, 1)).astype(tt.config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(tt.config.floatX),
    ]
    (jax_res,) = compare_jax_and_py(out_fg, test_input_vals)

    # Confirm that the `Subtensor` slice operations are correct
    assert jax_res.shape == (5, 3)

    # Confirm that the `IncSubtensor` operations are correct
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0
@pytest.mark.skip(reason="Not fully implemented, yet.")
def test_jax_scan():
    """Sketch of a `Scan` conversion test.

    Skipped: the JAX `Scan` conversion is unfinished.  The trailing
    ``assert False`` is a sentinel so the test cannot silently pass if the
    skip marker is removed before the conversion is done.
    """
    theano.config.compute_test_value = "raise"

    a_tt = tt.scalar("a")
    a_tt.tag.test_value = 3.0

    def input_step_fn(y_tm1, y_tm2, a):
        y_tm1.name = "y_tm1"
        y_tm2.name = "y_tm2"
        res = (y_tm1 + y_tm2) * a
        res.name = "y_t"
        return res

    y_scan_tt, _ = theano.scan(
        fn=input_step_fn,
        outputs_info=[
            {
                "initial": tt.as_tensor_variable(
                    np.r_[-1.0, 0.0].astype(tt.config.floatX)
                ),
                "taps": [-1, -2],
            },
        ],
        non_sequences=[a_tt],
        n_steps=10,
        name="y_scan",
    )
    y_scan_tt.name = "y"
    y_scan_tt.owner.inputs[0].name = "y_all"

    theano_scan_fn = theano.function([], y_scan_tt, givens={a_tt: 3.0})
    theano_res = theano_scan_fn()

    #
    # The equivalent JAX `scan`:
    #
    import jax
    import jax.numpy as jnp

    def jax_inner_scan(carry, x):
        (y_tm1, y_tm2), a = carry
        res = (y_tm1 + y_tm2) * a
        return [jnp.array([res, y_tm1]), a], res

    init_carry = [np.r_[0.0, -1.0].astype(tt.config.floatX), 3.0]
    tmp, jax_res = jax.lax.scan(jax_inner_scan, init_carry, None, length=10)

    assert np.allclose(jax_res, theano_res)

    out_fg = theano.gof.FunctionGraph([a_tt], [y_scan_tt])

    test_input_vals = [np.array(10.0).astype(tt.config.floatX)]
    (jax_res,) = compare_jax_and_py(out_fg, test_input_vals)

    # Sentinel: remove once the `Scan` conversion actually works.
    assert False
def test_jax_Subtensors():
    """Test JAX conversions of basic, boolean, and advanced indexing reads."""
    # Basic indices
    x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5))
    out_tt = x_tt[1, 2, 0]

    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    out_tt = x_tt[1:2, 1, :]

    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    # Boolean indices
    out_tt = x_tt[x_tt < 0]

    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    # Advanced indexing
    out_tt = x_tt[[1, 2]]

    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    out_tt = x_tt[[1, 2], [2, 3]]

    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    # Advanced and basic indexing
    out_tt = x_tt[[1, 2], :]

    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    out_tt = x_tt[[1, 2], :, [3, 4]]

    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
def test_jax_IncSubtensor():
    """Test JAX conversions of `IncSubtensor`-family `Op`s in both "set" and
    "increment" modes, over basic, advanced, and boolean indices.

    FIX: the two "Increment" cases below previously called
    `tt.set_subtensor` (copy-paste from the "Set" section), so the
    increment paths were never exercised.
    """
    x_np = np.empty((3, 4, 5), dtype=tt.config.floatX)
    x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(tt.config.floatX)

    # "Set" basic indices
    st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX))
    out_tt = tt.set_subtensor(x_tt[1, 2, 3], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
    out_tt = tt.set_subtensor(x_tt[:2, 0, 0], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    out_tt = tt.set_subtensor(x_tt[0, 1:3, 0], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    # "Set" advanced indices
    st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
    out_tt = tt.set_subtensor(x_tt[[0, 2], 0, 0], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_tt = tt.set_subtensor(x_tt[[0, 2], 0, :3], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    # "Set" boolean indices
    mask_tt = tt.as_tensor_variable(x_np) > 0
    out_tt = tt.set_subtensor(x_tt[mask_tt], 0.0)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    # "Increment" basic indices
    st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX))
    out_tt = tt.inc_subtensor(x_tt[1, 2, 3], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
    out_tt = tt.inc_subtensor(x_tt[:2, 0, 0], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    # FIX: was `tt.set_subtensor`, which re-tested the "set" path.
    out_tt = tt.inc_subtensor(x_tt[0, 1:3, 0], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    # "Increment" advanced indices
    st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
    out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, 0], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, :3], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])

    # "Increment" boolean indices
    # FIX: was `tt.set_subtensor`, which did not test incrementing at all.
    mask_tt = tt.as_tensor_variable(x_np) > 0
    out_tt = tt.inc_subtensor(x_tt[mask_tt], 1.0)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
def test_jax_ifelse():
    """Test the JAX conversion of `IfElse` on both branch outcomes."""
    true_vals = np.r_[1, 2, 3]
    false_vals = np.r_[-1, -2, -3]

    x = theano.ifelse.ifelse(np.array(True), true_vals, false_vals)
    x_fg = theano.gof.FunctionGraph([], [x])

    (jax_res,) = compare_jax_and_py(x_fg, [])

    x = theano.ifelse.ifelse(np.array(False), true_vals, false_vals)
    x_fg = theano.gof.FunctionGraph([], [x])

    (jax_res,) = compare_jax_and_py(x_fg, [])
def test_jax_CAReduce():
    """Test JAX conversions of `CAReduce`-based reductions (`sum`, `prod`,
    `all`) over full and partial axes."""
    a_tt = tt.vector("a")
    a_tt.tag.test_value = np.r_[1, 2, 3].astype(tt.config.floatX)

    x = tt.sum(a_tt, axis=None)
    x_fg = theano.gof.FunctionGraph([a_tt], [x])

    _ = compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(tt.config.floatX)])

    a_tt = tt.matrix("a")
    a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)

    x = tt.sum(a_tt, axis=0)
    x_fg = theano.gof.FunctionGraph([a_tt], [x])

    _ = compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])

    x = tt.sum(a_tt, axis=1)
    x_fg = theano.gof.FunctionGraph([a_tt], [x])

    _ = compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])

    a_tt = tt.matrix("a")
    a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)

    x = tt.prod(a_tt, axis=0)
    x_fg = theano.gof.FunctionGraph([a_tt], [x])

    _ = compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])

    x = tt.all(a_tt)
    x_fg = theano.gof.FunctionGraph([a_tt], [x])

    _ = compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
def test_jax_MakeVector():
    """Test the JAX conversion of `MakeVector`."""
    x = tt.opt.make_vector(1, 2, 3)
    x_fg = theano.gof.FunctionGraph([], [x])

    _ = compare_jax_and_py(x_fg, [])
def test_jax_Reshape():
    """Test the JAX conversion of `Reshape` with a static target shape."""
    a_tt = tt.vector("a")
    x = tt.basic.reshape(a_tt, (2, 2))
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(
        x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)]
    )
def test_jax_Reshape_omnistaging():
    """Test `Reshape` with a shape computed from the input itself.

    Regression test for breaking "omnistaging" changes in JAX.
    See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68
    """
    a_tt = tt.vector("a")
    # Shape components are symbolic (derived from `a_tt.shape`), not static.
    x = tt.basic.reshape(a_tt, (a_tt.shape[0] // 2, a_tt.shape[0] // 3))
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(x_fg, [np.empty((6,)).astype(theano.config.floatX)])
def test_jax_Dimshuffle():
    """Test JAX conversions of `DimShuffle`: transpose, axis augmentation,
    axis dropping, and the in-place variant."""
    a_tt = tt.matrix("a")

    x = a_tt.T
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(
        x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)]
    )

    # Add a broadcastable axis ("x") at the end.
    x = a_tt.dimshuffle([0, 1, "x"])
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(
        x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)]
    )

    # Drop the broadcastable second dimension.
    a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True])
    x = a_tt.dimshuffle((0,))
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)])

    # Same, but using the `inplace` variant of the `Op`.
    a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True])
    x = tt.elemwise.DimShuffle([False, True], (0,), inplace=True)(a_tt)
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)])
def test_jax_variadic_Scalar():
    """Test scalar `Op`s that can take more arguments than their declared
    arity (the `nfunc_variadic` special case in `jax_funcify_ScalarOp`)."""
    mu = tt.vector("mu", dtype=tt.config.floatX)
    mu.tag.test_value = np.r_[0.1, 1.1].astype(tt.config.floatX)
    tau = tt.vector("tau", dtype=tt.config.floatX)
    tau.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX)

    res = -tau * mu

    fgraph = theano.gof.FunctionGraph([mu, tau], [res])

    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    res = -tau * (tau - mu) ** 2

    fgraph = theano.gof.FunctionGraph([mu, tau], [res])

    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_logp():
    """Test a realistic composite graph: a normal log-density with a
    `switch`-guarded support condition (PyMC-style `logp`)."""
    mu = tt.vector("mu")
    mu.tag.test_value = np.r_[0.0, 0.0].astype(tt.config.floatX)
    tau = tt.vector("tau")
    tau.tag.test_value = np.r_[1.0, 1.0].astype(tt.config.floatX)
    sigma = tt.vector("sigma")
    sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(tt.config.floatX)
    value = tt.vector("value")
    value.tag.test_value = np.r_[0.1, -10].astype(tt.config.floatX)

    logp = (-tau * (value - mu) ** 2 + tt.log(tau / np.pi / 2.0)) / 2.0
    conditions = [sigma > 0]
    alltrue = tt.all([tt.all(1 * val) for val in conditions])
    # Outside the support, the density is -inf.
    normal_logp = tt.switch(alltrue, logp, -np.inf)

    fgraph = theano.gof.FunctionGraph([mu, tau, sigma, value], [normal_logp])

    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_multioutput():
    """Test a `FunctionGraph` with more than one output."""
    x = tt.vector("x")
    x.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX)
    y = tt.vector("y")
    y.tag.test_value = np.r_[3.0, 4.0].astype(tt.config.floatX)

    w = tt.cosh(x ** 2 + y / 3.0)
    v = tt.cosh(x / 3.0 + y ** 2)

    fgraph = theano.gof.FunctionGraph([x, y], [w, v])

    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......@@ -7,12 +7,13 @@ import logging
import warnings
import theano
from theano import gof
import theano.gof.vm
from theano import config
from six import string_types
from theano.compile.function_module import Supervisor
from theano import config, gof
from theano.compile.function_module import Supervisor
from theano.sandbox.jax_linker import JAXLinker
_logger = logging.getLogger("theano.compile.mode")
......@@ -29,6 +30,7 @@ predefined_linkers = {
"cvm": gof.vm.VM_Linker(use_cloop=True), # Use allow_gc Theano flag
"vm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=False),
"cvm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=True),
"jax": JAXLinker(),
}
......
from warnings import warn
from collections.abc import Sequence
from theano.gof.link import (
PerformLinker,
map_storage,
gc_helper,
utils,
add_clear_storage,
Container,
streamline,
)
from theano.gof.graph import Constant
class JAXLinker(PerformLinker):
    """A `Linker` that JIT-compiles NumPy-based operations using JAX.

    Attributes
    ----------
    allow_non_jax: bool
        A boolean indicating whether or not an exception is thrown when the
        graph cannot be JAX compiled (e.g. the graph has an unsupported
        operator).  If `allow_non_jax` is `True`, the fallback is currently
        Python compilation.
    """

    allow_non_jax = False

    def create_jax_thunks(self, compute_map, storage_map):
        """Create a thunk for each output of the `Linker`'s `FunctionGraph`.

        This differs from the other thunk-making functions in that it only
        produces thunks for the `FunctionGraph` output nodes.

        Parameters
        ----------
        compute_map: dict
            The compute map dictionary.
        storage_map: dict
            The storage map dictionary.

        Returns
        -------
        thunks: list
            The thunks, one per output node.
        output_nodes: list
            The output nodes, aligned with `thunks`.
        """
        import jax

        from theano.sandbox.jaxify import jax_funcify

        output_nodes = [o.owner for o in self.fgraph.outputs]

        # Create a JAX-compilable function from our `FunctionGraph`
        jaxed_fgraph_outputs = jax_funcify(self.fgraph)

        assert len(jaxed_fgraph_outputs) == len(output_nodes)

        # I suppose we can consider `Constant`s to be "static" according to
        # JAX.
        static_argnums = [
            n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
        ]

        thunk_inputs = [storage_map[n] for n in self.fgraph.inputs]

        thunks = []

        for node, jax_funcs in zip(output_nodes, jaxed_fgraph_outputs):

            thunk_outputs = [storage_map[n] for n in node.outputs]

            # Multi-output nodes are converted into a sequence of JAX
            # functions, one per output; JIT-compile each one.
            if len(node.outputs) > 1:
                # FIX: this previously read `node.ouptputs` (typo), which
                # raised `AttributeError` for any multi-output node.
                assert len(jax_funcs) == len(node.outputs)
                jax_impl_jits = [
                    jax.jit(jax_func, static_argnums) for jax_func in jax_funcs
                ]
            else:
                assert not isinstance(jax_funcs, Sequence)
                jax_impl_jits = [jax.jit(jax_funcs, static_argnums)]

            def thunk(
                node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs
            ):
                # Pull current values out of the input storage cells and run
                # each compiled output function on them.
                outputs = [
                    jax_impl_jit(*[x[0] for x in thunk_inputs])
                    for jax_impl_jit in jax_impl_jits
                ]

                # Write results back into the output storage cells and mark
                # them computed.
                for o_node, o_storage, o_val in zip(
                    node.outputs, thunk_outputs, outputs
                ):
                    compute_map[o_node][0] = True
                    if len(o_storage) > 1:
                        assert len(o_storage) == len(o_val)
                        for i, o_sub_val in enumerate(o_val):
                            o_storage[i] = o_sub_val
                    else:
                        o_storage[0] = o_val
                return outputs

            thunk.inputs = thunk_inputs
            thunk.outputs = thunk_outputs
            thunk.lazy = False

            thunks.append(thunk)

        return thunks, output_nodes

    def make_all(self, input_storage=None, output_storage=None, storage_map=None):
        """Build the compiled function, storage containers, thunks, and nodes.

        Mirrors `PerformLinker.make_all`, but tries JAX compilation first and
        (when `allow_non_jax` is set) falls back to per-node Python thunks.
        """
        fgraph = self.fgraph
        nodes = self.schedule(fgraph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            fgraph, nodes, input_storage, output_storage, storage_map
        )

        compute_map = {}
        for k in storage_map:
            compute_map[k] = [k.owner is None]

        try:
            # We need to create thunk functions that will populate the output
            # storage arrays with the JAX-computed values.
            thunks, nodes = self.create_jax_thunks(compute_map, storage_map)
        except NotImplementedError as e:
            if not self.allow_non_jax:
                raise

            # FIX: warning previously misnamed the class as "JaxLinker".
            warn("JAXLinker could not JAXify graph: {}".format(e))

            # Fallback: one Python thunk per node, as `PerformLinker` does.
            thunks = []
            for node in nodes:
                thunk = node.op.make_thunk(
                    node, storage_map, compute_map, no_recycling, "py"
                )
                thunk_inputs = [storage_map[v] for v in node.inputs]
                thunk_outputs = [storage_map[v] for v in node.outputs]

                thunk.inputs = thunk_inputs
                thunk.outputs = thunk_outputs

                thunks.append(thunk)

        computed, last_user = gc_helper(nodes)

        if self.allow_gc:
            # For each node, the input storages that can be dropped right
            # after that node runs (their last consumer is this node).
            post_thunk_old_storage = []

            for node in nodes:
                post_thunk_old_storage.append(
                    [
                        storage_map[input]
                        for input in node.inputs
                        if (input in computed)
                        and (input not in fgraph.outputs)
                        and (node == last_user[input])
                    ]
                )
        else:
            post_thunk_old_storage = None

        if no_recycling is True:
            no_recycling = list(storage_map.values())
            no_recycling = utils.difference(no_recycling, input_storage)
        else:
            no_recycling = [
                storage_map[r] for r in no_recycling if r not in fgraph.inputs
            ]

        fn = streamline(
            fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
        )

        fn.allow_gc = self.allow_gc
        add_clear_storage(fn, computed, storage_map)
        fn.storage_map = storage_map

        return (
            fn,
            [
                Container(input, storage)
                for input, storage in zip(fgraph.inputs, input_storage)
            ],
            [
                Container(output, storage, True)
                for output, storage in zip(fgraph.outputs, output_storage)
            ],
            thunks,
            nodes,
        )
import theano
import jax
import jax.numpy as jnp
from warnings import warn
from functools import partial, update_wrapper, reduce
from collections.abc import Sequence
from functools import singledispatch as dispatch
from theano.gof import FunctionGraph
from theano.ifelse import IfElse
from theano.tensor.subtensor import (
get_idx_list,
Subtensor,
IncSubtensor,
# This is essentially `np.take`
AdvancedSubtensor1,
AdvancedIncSubtensor1,
# Boolean mask indexing and setting
BaseAdvancedSubtensor,
BaseAdvancedIncSubtensor,
)
from theano.scan_module.scan_op import Scan
from theano.scan_module.scan_utils import scan_args as ScanArgs
from theano.tensor.basic import (
TensorFromScalar,
ScalarFromTensor,
AllocEmpty,
Alloc,
Reshape,
Join,
)
from theano.scalar.basic import (
ScalarOp,
Composite,
Cast,
Clip,
)
from theano.tensor.elemwise import Elemwise, CAReduce, DimShuffle
from theano.compile.ops import (
DeepCopyOp,
Shape,
Shape_i,
SpecifyShape,
Rebroadcast,
ViewOp,
)
from theano.tensor.opt import MakeVector
# JAX defaults to 32-bit computation; enable 64-bit so results match
# Theano/NumPy float64 semantics.
jax.config.update("jax_enable_x64", True)

# Groups of subtensor `Op`s that share one conversion each: reads
# (indexing) vs. writes (set/inc), over basic, advanced, and boolean indices.
subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor)
def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
    """Compose JAX implementations of node operations.

    Recursively walks the graph from `out_node` toward the graph inputs,
    converts each `Apply` node's `Op` via `jax_funcify`, and wires the
    resulting callables together.

    Parameters
    ----------
    out_node: Node
        The output node.
    fgraph_inputs: List[Variable]
        The inputs--in a `FunctionGraph` sense--to `out_node`.
    memo: Mapping (Optional)
        A map from visited nodes to their JAX functions; avoids rebuilding
        shared sub-graphs.

    Outputs
    -------
    A `function` object that represents the composed JAX operations and takes
    the same form of inputs as `fgraph_inputs`.
    """
    if memo is None:
        memo = {}

    if out_node in memo:
        return memo[out_node]

    jax_return_func = jax_funcify(out_node.op)

    input_funcs = []
    for i in out_node.inputs:
        if i in fgraph_inputs:
            # Graph input: select the matching positional argument.
            idx = fgraph_inputs.index(i)

            def jax_inputs_func(*inputs, i_dtype=i.dtype, idx=idx):
                return jnp.array(inputs[idx], dtype=jnp.dtype(i_dtype))

            input_f = jax_inputs_func

        elif i.owner is None:
            # Ownerless variable (e.g. a constant): bake its data in.
            def jax_data_func(*inputs, i_dtype=i.dtype, i_data=i.data):
                return jnp.array(i_data, dtype=jnp.dtype(i_dtype))

            input_f = jax_data_func
        else:
            # Intermediate value: recurse into the node that produces it.
            input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)

        input_funcs.append(input_f)

    if not isinstance(jax_return_func, Sequence):
        jax_return_func = [jax_return_func]

    jax_funcs = []
    for return_func in jax_return_func:

        # FIX: `return_func` must be bound as a default argument.  The
        # previous closure captured the loop variable, so for multi-output
        # `Op`s every composed function ended up calling the *last* output's
        # implementation (classic late-binding-closure bug).
        def jax_func(*inputs, return_func=return_func):
            func_args = [fn(*inputs) for fn in input_funcs]
            return return_func(*func_args)

        jax_funcs.append(update_wrapper(jax_func, return_func))

    if len(out_node.outputs) == 1:
        jax_funcs = jax_funcs[0]

    memo[out_node] = jax_funcs

    return jax_funcs
@dispatch
def jax_funcify(op):
    """Create a JAX "perform" function for a Theano `Variable` and its `Op`.

    This is the `functools.singledispatch` base implementation; concrete
    `Op` types register their conversions below.  Unregistered `Op`s fail
    loudly here.
    """
    raise NotImplementedError("No JAX conversion for the given `Op`: {}".format(op))
@jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op):
    # Resolve the NumPy-style function name declared by the `Op`.
    func_name = op.nfunc_spec[0]

    if "." in func_name:
        # Dotted names are resolved attribute-by-attribute from the `jax`
        # package root.
        jnp_func = reduce(getattr, [jax] + func_name.split("."))
    else:
        jnp_func = getattr(jnp, func_name)

    if hasattr(op, "nfunc_variadic"):
        # These are special cases that handle invalid arities due to the broken
        # Theano `Op` type contract (e.g. binary `Op`s that also function as
        # their own variadic counterparts--even when those counterparts already
        # exist as independent `Op`s).
        jax_variadic_func = getattr(jnp, op.nfunc_variadic)

        def elemwise(*args):
            if len(args) > op.nfunc_spec[1]:
                # More arguments than the declared arity: apply the variadic
                # reduction over the broadcast, stacked arguments instead.
                return jax_variadic_func(
                    jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
                )
            else:
                return jnp_func(*args)

        return elemwise
    else:
        return jnp_func
@jax_funcify.register(Clip)
def jax_funcify_Clip(op):
    # Reuse the scalar `Op`'s own Python `impl`, with its first parameter
    # pre-bound to `None`.
    # NOTE(review): assumes `Clip.impl`'s leading parameter is unused here
    # and that the implementation is JAX-traceable -- confirm against
    # `theano.scalar.basic.Clip`.
    return partial(op.impl, None)
@jax_funcify.register(AllocEmpty)
def jax_funcify_AllocEmpty(op):
    """JAX implementation of `AllocEmpty`: an uninitialized array of the
    requested shape, using the `Op`'s dtype."""
    dtype = op.dtype

    def alloc_empty(*shape):
        return jnp.empty(shape, dtype=dtype)

    return alloc_empty
@jax_funcify.register(Alloc)
def jax_funcify_Alloc(op):
    """JAX implementation of `Alloc`: broadcast a value to a given shape."""

    def alloc(x, *shape):
        return jnp.broadcast_to(x, shape)

    return alloc
def jnp_safe_copy(x):
    """Copy `x` with `jnp.copy`, degrading gracefully when that is not
    implemented: fall back to the object's own `copy` method, or--failing
    that--return the object itself (with a warning in both fallback cases).
    """
    try:
        return jnp.copy(x)
    except NotImplementedError:
        warn("`jnp.copy` is not implemented yet. " "Using the object's `copy` method.")
        if hasattr(x, "copy"):
            return jnp.array(x.copy())
        warn("Object has no `copy` method: {}".format(x))
        return x
@jax_funcify.register(DeepCopyOp)
def jax_funcify_DeepCopyOp(op):
    def deepcopyop(x):
        # Delegate to the module-level helper, which falls back to the
        # object's own `copy` when `jnp.copy` is unavailable.
        return jnp_safe_copy(x)

    return deepcopyop
@jax_funcify.register(Shape)
def jax_funcify_Shape(op):
    """JAX implementation of `Shape`: return the input's shape."""

    def get_shape(x):
        return jnp.shape(x)

    return get_shape
@jax_funcify.register(Shape_i)
def jax_funcify_Shape_i(op):
    # The dimension index is fixed when the `Op` is constructed.
    i = op.i

    def shape_i(x):
        return jnp.shape(x)[i]

    return shape_i
@jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op):
    def specifyshape(x, shape):
        # NOTE(review): plain `assert`s are stripped under `python -O`, in
        # which case these shape checks silently disappear.
        assert x.ndim == shape.size
        assert jnp.all(x.shape == shape), ("got shape", x.shape, "expected", shape)
        return x

    return specifyshape
@jax_funcify.register(Rebroadcast)
def jax_funcify_Rebroadcast(op):
    op_axis = op.axis

    def rebroadcast(x):
        # Mirror Theano's runtime check: any axis flagged broadcastable
        # must actually have length 1.  The value itself is unchanged.
        for axis, value in op_axis.items():
            if value and x.shape[axis] != 1:
                raise ValueError(
                    "Dimension %s in Rebroadcast's input was"
                    " supposed to be 1 (got %s instead)" % (axis, x.shape[axis])
                )
        return x

    return rebroadcast
@jax_funcify.register(ViewOp)
def jax_funcify_ViewOp(op):
    """JAX implementation of `ViewOp`: an identity pass-through (no copy is
    needed on the JAX side)."""

    def view_op(x):
        return x

    return view_op
@jax_funcify.register(Cast)
def jax_funcify_Cast(op):
    def cast(x):
        # Convert via `jnp.array` first so plain Python scalars are handled,
        # then cast to the `Op`'s declared output dtype.
        return jnp.array(x).astype(op.o_type.dtype)

    return cast
@jax_funcify.register(TensorFromScalar)
def jax_funcify_TensorFromScalar(op):
    def tensor_from_scalar(x):
        # Wrap the scalar as a (0-d) JAX array.
        return jnp.array(x)

    return tensor_from_scalar
@jax_funcify.register(ScalarFromTensor)
def jax_funcify_ScalarFromTensor(op):
    def scalar_from_tensor(x):
        # Extract the single element; assumes `x` holds exactly one value
        # (the `Op`'s contract) -- extra elements would be silently dropped.
        return jnp.array(x).flatten()[0]

    return scalar_from_tensor
@jax_funcify.register(Elemwise)
def jax_funcify_Elemwise(op):
    # JAX/NumPy functions already broadcast element-wise, so an `Elemwise`
    # reduces to the conversion of its underlying scalar `Op`.
    scalar_op = op.scalar_op
    return jax_funcify(scalar_op)
@jax_funcify.register(Composite)
def jax_funcify_Composite(op):
    # A `Composite` carries its own inner `FunctionGraph`; convert that.
    jax_impl = jax_funcify(op.fgraph)
    return jax_impl
@jax_funcify.register(Scan)
def jax_funcify_Scan(op):
    """Convert a Theano `Scan` to `jax.lax.scan`.

    NOTE(review): unfinished -- both inner translation helpers raise
    `NotImplementedError`, so this conversion cannot run end-to-end yet
    (see the corresponding skipped test).
    """
    inner_fg = FunctionGraph(op.inputs, op.outputs)
    jax_tt_inner_func = jax_funcify(inner_fg)

    def scan(*outer_inputs):
        scan_args = ScanArgs(
            outer_inputs, [None] * op.n_outs, op.inputs, op.outputs, op.info
        )

        # `outer_inputs` is a list with the following composite form:
        # [n_steps]
        # + outer_in_seqs
        # + outer_in_mit_mot
        # + outer_in_mit_sot
        # + outer_in_sit_sot
        # + outer_in_shared
        # + outer_in_nit_sot
        # + outer_in_non_seqs
        n_steps = scan_args.n_steps
        seqs = scan_args.outer_in_seqs

        n_non_seqs = len(scan_args.outer_in_non_seqs)

        # TODO: sit_sots
        mit_sot_in_slices = []
        for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
            # Build each mit-sot's initial window from its largest taps.
            neg_taps = [abs(t) for t in tap if t < 0]
            pos_taps = [abs(t) for t in tap if t > 0]
            max_neg = max(neg_taps) if neg_taps else 0
            max_pos = max(pos_taps) if pos_taps else 0
            init_slice = seq[: max_neg + max_pos]
            mit_sot_in_slices.append(init_slice)

        init_carry = [mit_sot_in_slices, scan_args.outer_in_non_seqs]

        def jax_args_to_inner_scan(op, carry, x):
            # `carry` contains all inner-output taps, non_seqs, and shared
            # terms
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = carry

            # `x` contains the in_seqs
            inner_in_seqs = x

            # `inner_scan_inputs` is a list with the following composite form:
            # inner_in_seqs
            # + sum(inner_in_mit_mot, [])
            # + sum(inner_in_mit_sot, [])
            # + inner_in_sit_sot
            # + inner_in_shared
            # + inner_in_non_seqs
            inner_scan_inputs = [
                inner_in_seqs,
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_non_seqs,
            ]

            # TODO: translation of the carry/x layout is not implemented.
            raise NotImplementedError()
            return inner_scan_inputs

        def inner_scan_outs_to_jax_outs(
            op,
            old_carry,
            inner_scan_outs,
        ):
            # `inner_scan_outs` is a list with the following
            # composite form:
            # outer_out_mit_mot
            # + outer_out_mit_sot
            # + outer_out_sit_sot
            # + outer_out_nit_sot
            # + outer_out_shared
            # + cond
            (
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_nit_sot,
                outer_out_shared,
                cond,
            ) = inner_scan_outs

            # NOTE(review): `[:-n_non_seqs]` selects everything *except* the
            # trailing non-seqs; the variable name suggests `[-n_non_seqs:]`
            # was intended -- confirm when this gets implemented.
            outer_out_non_seqs = old_carry[:-n_non_seqs]

            # This should contain all inner-output taps, non_seqs, and shared
            # terms
            carry = [
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_shared,
                outer_out_non_seqs,
            ]

            # This should contain all inner-outputs that produce
            # outer-outputs
            y = []

            # TODO: output translation is not implemented.
            raise NotImplementedError()
            return (carry, y)

        def jax_inner_func(carry, x):
            inner_args = jax_args_to_inner_scan(op, carry, x)
            inner_scan_outs = jax_tt_inner_func(*inner_args)
            new_carry, y = inner_scan_outs_to_jax_outs(op, inner_scan_outs)
            return new_carry, y

        return jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)

    return scan
@jax_funcify.register(IfElse)
def jax_funcify_IfElse(op):
    def ifelse(cond, *args):
        # Assumes the first `op.n_outs` entries of `args` are the
        # true-branch values and the rest the false-branch values --
        # TODO confirm against `theano.ifelse.IfElse`'s input layout.
        if cond:
            return args[: op.n_outs]
        else:
            return args[op.n_outs :]

    return ifelse
def convert_indices(indices, entry):
    """Turn a `Subtensor` index-template `entry` into a concrete index.

    `indices` is consumed destructively: each `Type` placeholder in the
    template pops the next runtime index value off the front.
    """
    if indices and isinstance(entry, theano.gof.Type):
        rval = indices.pop(0)
        return rval
    elif isinstance(entry, slice):
        # Recurse into the slice components, each of which may itself be a
        # placeholder, a nested slice, or a constant.
        return slice(
            convert_indices(indices, entry.start),
            convert_indices(indices, entry.stop),
            convert_indices(indices, entry.step),
        )
    else:
        # Constant index: use as-is.
        return entry
@jax_funcify.register(Subtensor)
def jax_funcify_Subtensor(op):
    """Create a JAX function implementing (advanced/boolean) indexed reads."""
    idx_list = getattr(op, "idx_list", None)

    def subtensor(x, *ilists):
        # Static `Subtensor` ops carry an `idx_list` template whose slots
        # are filled from the runtime index inputs; advanced subtensor ops
        # pass their indices directly.
        if idx_list:
            cdata = get_idx_list((x,) + ilists, idx_list)
        else:
            cdata = ilists

        if len(cdata) == 1:
            cdata = cdata[0]

        # FIX: use idiomatic indexing instead of `x.__getitem__(cdata)`,
        # and drop the leftover `breakpoint()`/alternate-implementation
        # debug comments.
        return x[cdata]

    return subtensor
# Register the same conversion for every read-style subtensor `Op`.
# FIX: this was a throwaway list comprehension used purely for its side
# effects; a plain loop is the idiomatic form.
for _subtensor_op in subtensor_ops:
    jax_funcify.register(_subtensor_op, jax_funcify_Subtensor)
def jax_funcify_IncSubtensor(op):
    # `set_instead_of_inc` distinguishes `set_subtensor` (overwrite) from
    # `inc_subtensor` (add).
    # NOTE(review): `jax.ops.index_update`/`index_add` only exist in older
    # JAX releases (consistent with the `jax<0.2.0` pin in requirements).
    if getattr(op, "set_instead_of_inc", False):
        jax_fn = jax.ops.index_update
    else:
        jax_fn = jax.ops.index_add

    def incsubtensor(x, y, *ilist, jax_fn=jax_fn):
        # Fill the `Op`'s index template with the runtime index values
        # (`convert_indices` consumes `_ilist` destructively).
        _ilist = list(ilist)
        cdata = tuple(convert_indices(_ilist, idx) for idx in op.idx_list)
        if len(cdata) == 1:
            cdata = cdata[0]

        return jax_fn(x, cdata, y)

    return incsubtensor
# Register the same conversion for every write-style subtensor `Op`.
# FIX: this was a throwaway list comprehension used purely for its side
# effects; a plain loop is the idiomatic form.
for _incsubtensor_op in incsubtensor_ops:
    jax_funcify.register(_incsubtensor_op, jax_funcify_IncSubtensor)
@jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph(fgraph):
    # One composed JAX function per output node.  Outputs with no owner
    # (e.g. graph inputs or constants passed straight through) are skipped.
    out_nodes = [r.owner for r in fgraph.outputs if r.owner is not None]
    jax_funcs = [compose_jax_funcs(o, fgraph.inputs) for o in out_nodes]
    return jax_funcs
@jax_funcify.register(CAReduce)
def jax_funcify_CAReduce(op):
    """Create a JAX function for a commutative-associative reduction,
    implemented via `jax.lax.reduce` over the `Op`'s axes."""

    def careduce(x):
        axis = op.axis
        if axis is None:
            axis = list(range(x.ndim))

        # FIX: this was `reversed(sorted(axis))` -- an iterator, which is
        # always truthy (so the empty-axis fast path below never fired) and
        # is exhausted after a single pass.  Use a materialized list in
        # descending order instead.
        to_reduce = sorted(axis, reverse=True)

        if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
            acc_dtype = op.acc_dtype
        else:
            acc_dtype = x.dtype.type

        if to_reduce:
            if getattr(op.scalar_op, "name", None):
                jax_op = getattr(jax.lax, op.scalar_op.name)
            elif getattr(op.scalar_op, "nfunc_spec", None):
                # In this case, we need to use the `jax.lax` function (if there
                # is one), and not the `jnp` version.
                jax_op = getattr(jax.lax, op.scalar_op.nfunc_spec[0])
            # NOTE(review): a scalar `Op` with neither `name` nor
            # `nfunc_spec` would leave `jax_op` unbound here (NameError).

            init_value = jnp.array(op.scalar_op.identity, dtype=acc_dtype)
            return jax.lax.reduce(x, init_value, jax_op, to_reduce).astype(acc_dtype)
        else:
            # Nothing to reduce over.
            return x

    return careduce
@jax_funcify.register(MakeVector)
def jax_funcify_MakeVector(op):
    """JAX implementation of `MakeVector`: pack the arguments into a 1-d
    array with the `Op`'s dtype."""
    dtype = op.dtype

    def make_vector(*elems):
        return jnp.array(elems, dtype=dtype)

    return make_vector
@jax_funcify.register(Reshape)
def jax_funcify_Reshape(op):
    """JAX implementation of `Reshape`."""

    def do_reshape(x, shape):
        return jnp.reshape(x, shape)

    return do_reshape
@jax_funcify.register(DimShuffle)
def jax_funcify_DimShuffle(op):
    def dimshuffle(x):
        # Reorder the kept axes; the dropped axes are appended so that
        # `transpose` sees every input dimension.
        res = jnp.transpose(x, op.shuffle + op.drop)

        # Keep only the shuffled axes, then insert the new broadcastable
        # (length-1) axes at the positions in `op.augment`.
        shape = list(res.shape[: len(op.shuffle)])

        for augm in op.augment:
            shape.insert(augm, 1)

        res = jnp.reshape(res, shape)

        if not op.inplace:
            res = jnp_safe_copy(res)

        return res

    return dimshuffle
@jax_funcify.register(Join)
def jax_funcify_Join(op):
    def join(axis, *tensors):
        view = op.view
        if (view != -1) and all(
            [
                tensor.shape[axis] == 0
                for tensor in tensors[0:view] + tensors[view + 1 :]
            ]
        ):
            # Every other input is empty along `axis`, so the join is just
            # the single non-empty tensor.
            return tensors[view]
        else:
            ndim = tensors[0].ndim
            if axis < -ndim:
                raise IndexError("Join axis %d out of bounds [0, %d)" % (axis, ndim))

            return jnp.concatenate(tensors, axis=axis)

    return join
......@@ -2988,6 +2988,8 @@ class Inv(UnaryScalarOp):
"""
nfunc_spec = ("reciprocal", 1, 1)
def impl(self, x):
    # 1/x with a float32 numerator; presumably so integer inputs yield a
    # floating-point reciprocal without forcing float64 -- confirm against
    # the `Op`'s dtype-inference rules.
    return np.float32(1.0) / x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论