Unverified commit 98875d1e authored by Thomas Wiecki, committed by GitHub

Merge pull request #21 from brandonwillard/jax-linker

Introduce a JAX Linker class
...@@ -60,6 +60,7 @@ install: ...@@ -60,6 +60,7 @@ install:
- conda create --yes -q -n pyenv python=$TRAVIS_PYTHON_VERSION - conda create --yes -q -n pyenv python=$TRAVIS_PYTHON_VERSION
- conda activate pyenv - conda activate pyenv
- conda install --yes -q mkl numpy scipy pip mkl-service graphviz cython libgpuarray pygpu - conda install --yes -q mkl numpy scipy pip mkl-service graphviz cython libgpuarray pygpu
- if [[ "$TRAVIS_PYTHON_VERSION" != "3.6" ]]; then conda install --yes -q -c conda-forge jax jaxlib; fi
- pip install -q -r requirements.txt - pip install -q -r requirements.txt
- conda list && pip freeze - conda list && pip freeze
- python -c 'import theano; print(theano.config.__str__(print_doc=False))' - python -c 'import theano; print(theano.config.__str__(print_doc=False))'
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
flake8 flake8
pep8 pep8
pyflakes pyflakes
black==20.8b1; platform_python_implementation != 'PyPy'
pytest-cov>=2.6.1 pytest-cov>=2.6.1
coverage>=5.1 coverage>=5.1
pytest pytest
...@@ -10,3 +10,5 @@ coveralls ...@@ -10,3 +10,5 @@ coveralls
cython cython
sympy sympy
versioneer versioneer
jax; python_version > '3.6'
jaxlib; python_version > '3.6'
import pytest
import numpy as np
import theano
import theano.tensor as tt
jax = pytest.importorskip("jax")
from theano.gof.op import get_test_value # noqa: E402
@pytest.fixture(scope="module", autouse=True)
def set_theano_flags():
    """Run every test in this module without the C backend and with test values enabled."""
    flag_ctx = theano.change_flags(cxx="", compute_test_value="warn")
    with flag_ctx:
        yield
def compare_jax_and_py(fgraph, inputs, cmp_fn=np.allclose):
    """Compile `fgraph` in both JAX and Python modes and compare the results.

    Parameters
    ----------
    fgraph : theano.gof.FunctionGraph
        The graph whose inputs/outputs are compiled and evaluated.
    inputs : sequence
        Concrete values for ``fgraph.inputs``, passed to both compiled functions.
    cmp_fn : callable
        Two-argument predicate used to compare the JAX and Python results
        (defaults to `np.allclose`).

    Returns
    -------
    The JAX-mode result (asserted to consist of `DeviceArray`s).
    """
    # jax_mode = theano.compile.Mode(linker="jax")
    jax_mode = "JAX"
    theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
    jax_res = theano_jax_fn(*inputs)
    # The JAX linker should produce JAX device arrays, not NumPy arrays.
    if isinstance(jax_res, list):
        assert all(isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res)
    else:
        assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
    # Reference evaluation with the pure-Python linker.
    py_mode = theano.compile.Mode(linker="py")
    theano_py_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=py_mode)
    py_res = theano_py_fn(*inputs)
    assert cmp_fn(jax_res, py_res)
    return jax_res
def test_jax_Alloc():
    """Check JAX conversions of `Alloc` and `AllocEmpty` graphs."""
    x = tt.alloc(0.0, 2, 3)
    x_fg = theano.gof.FunctionGraph([], [x])
    (jax_res,) = compare_jax_and_py(x_fg, [])
    assert jax_res.shape == (2, 3)
    x = tt.alloc(1.1, 2, 3)
    x_fg = theano.gof.FunctionGraph([], [x])
    compare_jax_and_py(x_fg, [])
    x = theano.tensor.basic.AllocEmpty("float32")(2, 3)
    x_fg = theano.gof.FunctionGraph([], [x])
    def compare_shape_dtype(x, y):
        # `AllocEmpty` contents are uninitialized, so only shape and dtype
        # of the results can be meaningfully compared.
        (x,) = x
        (y,) = y
        return x.shape == y.shape and x.dtype == y.dtype
    (jax_res,) = compare_jax_and_py(x_fg, [], cmp_fn=compare_shape_dtype)
    a = tt.scalar("a")
    x = tt.alloc(a, 20)
    x_fg = theano.gof.FunctionGraph([a], [x])
    (jax_res,) = compare_jax_and_py(x_fg, [10.0])
    a = tt.vector("a")
    x = tt.alloc(a, 20, 10)
    x_fg = theano.gof.FunctionGraph([a], [x])
    (jax_res,) = compare_jax_and_py(x_fg, [np.ones(10, dtype=tt.config.floatX)])
def test_jax_compile_ops():
    """Check JAX conversions of the `theano.compile.ops` `Op`s.

    Covers `DeepCopyOp`, `Shape`, `Shape_i`, `SpecifyShape`, `Rebroadcast`
    and `ViewOp`, including the failure modes of the shape-checking `Op`s.
    """
    x = theano.compile.ops.DeepCopyOp()(tt.as_tensor_variable(1.1))
    x_fg = theano.gof.FunctionGraph([], [x])
    (jax_res,) = compare_jax_and_py(x_fg, [])
    x_np = np.zeros((20, 3))
    x = theano.compile.ops.Shape()(tt.as_tensor_variable(x_np))
    x_fg = theano.gof.FunctionGraph([], [x])
    (jax_res,) = compare_jax_and_py(x_fg, [])
    x = theano.compile.ops.Shape_i(1)(tt.as_tensor_variable(x_np))
    x_fg = theano.gof.FunctionGraph([], [x])
    (jax_res,) = compare_jax_and_py(x_fg, [])
    x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (20, 3))
    x_fg = theano.gof.FunctionGraph([], [x])
    (jax_res,) = compare_jax_and_py(x_fg, [])
    # A mismatched `SpecifyShape` should raise at run time; test values must
    # be off so the error isn't raised during graph construction instead.
    with theano.change_flags(compute_test_value="off"):
        x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (2, 3))
        x_fg = theano.gof.FunctionGraph([], [x])
        with pytest.raises(AssertionError):
            (jax_res,) = compare_jax_and_py(x_fg, [])
    x_np = np.zeros((20, 1, 1))
    x = theano.compile.ops.Rebroadcast((0, False), (1, True), (2, False))(
        tt.as_tensor_variable(x_np)
    )
    x_fg = theano.gof.FunctionGraph([], [x])
    (jax_res,) = compare_jax_and_py(x_fg, [])
    # Rebroadcasting axis 0 (length 20) to broadcastable must fail.
    with theano.change_flags(compute_test_value="off"):
        x = theano.compile.ops.Rebroadcast((0, True), (1, False), (2, False))(
            tt.as_tensor_variable(x_np)
        )
        x_fg = theano.gof.FunctionGraph([], [x])
        with pytest.raises(ValueError):
            (jax_res,) = compare_jax_and_py(x_fg, [])
    x = theano.compile.ops.ViewOp()(tt.as_tensor_variable(x_np))
    x_fg = theano.gof.FunctionGraph([], [x])
    (jax_res,) = compare_jax_and_py(x_fg, [])
def test_jax_basic():
    """Check JAX conversion of a mixed graph: scalar math, `[Inc]Subtensor`, `Clip`."""
    x = tt.matrix("x")
    y = tt.matrix("y")
    # `ScalarOp`
    z = tt.cosh(x ** 2 + y / 3.0)
    # `[Inc]Subtensor`
    out = tt.set_subtensor(z[0], -10.0)
    out = tt.inc_subtensor(out[0, 1], 2.0)
    out = out[:5, :3]
    out_fg = theano.gof.FunctionGraph([x, y], [out])
    test_input_vals = [
        np.tile(np.arange(10), (10, 1)).astype(tt.config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(tt.config.floatX),
    ]
    (jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
    # Confirm that the `Subtensor` slice operations are correct
    assert jax_res.shape == (5, 3)
    # Confirm that the `IncSubtensor` operations are correct
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0
    out = tt.clip(x, y, 5)
    out_fg = theano.gof.FunctionGraph([x, y], [out])
    (jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
@pytest.mark.skip(reason="Not fully implemented, yet.")
def test_jax_scan():
    """Sketch of the `Scan` conversion test (currently skipped).

    Builds a two-tap recurrence with `theano.scan`, shows the equivalent
    hand-written `jax.lax.scan`, and (once implemented) would compare the
    converted graph against the Python backend.
    """
    theano.config.compute_test_value = "raise"
    a_tt = tt.scalar("a")
    a_tt.tag.test_value = 3.0
    def input_step_fn(y_tm1, y_tm2, a):
        y_tm1.name = "y_tm1"
        y_tm2.name = "y_tm2"
        res = (y_tm1 + y_tm2) * a
        res.name = "y_t"
        return res
    y_scan_tt, _ = theano.scan(
        fn=input_step_fn,
        outputs_info=[
            {
                "initial": tt.as_tensor_variable(
                    np.r_[-1.0, 0.0].astype(tt.config.floatX)
                ),
                "taps": [-1, -2],
            },
        ],
        non_sequences=[a_tt],
        n_steps=10,
        name="y_scan",
    )
    y_scan_tt.name = "y"
    y_scan_tt.owner.inputs[0].name = "y_all"
    theano_scan_fn = theano.function([], y_scan_tt, givens={a_tt: 3.0})
    theano_res = theano_scan_fn()
    #
    # The equivalent JAX `scan`:
    #
    import jax
    import jax.numpy as jnp
    def jax_inner_scan(carry, x):
        (y_tm1, y_tm2), a = carry
        res = (y_tm1 + y_tm2) * a
        return [jnp.array([res, y_tm1]), a], res
    init_carry = [np.r_[0.0, -1.0].astype(tt.config.floatX), 3.0]
    tmp, jax_res = jax.lax.scan(jax_inner_scan, init_carry, None, length=10)
    assert np.allclose(jax_res, theano_res)
    out_fg = theano.gof.FunctionGraph([a_tt], [y_scan_tt])
    test_input_vals = [np.array(10.0).astype(tt.config.floatX)]
    (jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
    # NOTE(review): deliberate failure marker so this test cannot silently
    # "pass" if the skip is removed before the conversion is finished.
    assert False
def test_jax_Subtensors():
    """Check JAX conversion of basic, boolean, and advanced `Subtensor` indexing."""
    # Basic indices
    x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5))
    out_tt = x_tt[1, 2, 0]
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    out_tt = x_tt[1:2, 1, :]
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    # Boolean indices
    out_tt = x_tt[x_tt < 0]
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    # Advanced indexing
    out_tt = x_tt[[1, 2]]
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    out_tt = x_tt[[1, 2], [2, 3]]
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    # Advanced and basic indexing
    out_tt = x_tt[[1, 2], :]
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    out_tt = x_tt[[1, 2], :, [3, 4]]
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
def test_jax_IncSubtensor():
    """Check JAX conversion of `IncSubtensor` for set/increment with basic,
    advanced, and boolean indices.

    Fixes over the previous version:
    - `x_np` was built with `np.empty`, i.e. uninitialized memory, which made
      the boolean-mask cases nondeterministic; it now holds defined values
      spanning both signs so the `> 0` masks are meaningful.
    - The two cases under the "Increment" headings that mistakenly called
      `tt.set_subtensor` now call `tt.inc_subtensor` as the comments claim.
    """
    x_np = np.linspace(-1.0, 1.0, 3 * 4 * 5, dtype=tt.config.floatX).reshape((3, 4, 5))
    x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(tt.config.floatX)
    # "Set" basic indices
    st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX))
    out_tt = tt.set_subtensor(x_tt[1, 2, 3], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
    out_tt = tt.set_subtensor(x_tt[:2, 0, 0], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    out_tt = tt.set_subtensor(x_tt[0, 1:3, 0], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    # "Set" advanced indices
    st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
    out_tt = tt.set_subtensor(x_tt[[0, 2], 0, 0], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_tt = tt.set_subtensor(x_tt[[0, 2], 0, :3], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    # "Set" boolean indices
    mask_tt = tt.as_tensor_variable(x_np) > 0
    out_tt = tt.set_subtensor(x_tt[mask_tt], 0.0)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    # "Increment" basic indices
    st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX))
    out_tt = tt.inc_subtensor(x_tt[1, 2, 3], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
    out_tt = tt.inc_subtensor(x_tt[:2, 0, 0], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    out_tt = tt.inc_subtensor(x_tt[0, 1:3, 0], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    # "Increment" advanced indices
    st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
    out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, 0], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, :3], st_tt)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
    # "Increment" boolean indices
    mask_tt = tt.as_tensor_variable(x_np) > 0
    out_tt = tt.inc_subtensor(x_tt[mask_tt], 1.0)
    out_fg = theano.gof.FunctionGraph([], [out_tt])
    (jax_res,) = compare_jax_and_py(out_fg, [])
def test_jax_ifelse():
    """Check JAX conversion of `IfElse` with constant conditions."""
    true_vals = np.r_[1, 2, 3]
    false_vals = np.r_[-1, -2, -3]
    x = theano.ifelse.ifelse(np.array(True), true_vals, false_vals)
    x_fg = theano.gof.FunctionGraph([], [x])
    (jax_res,) = compare_jax_and_py(x_fg, [])
    x = theano.ifelse.ifelse(np.array(False), true_vals, false_vals)
    x_fg = theano.gof.FunctionGraph([], [x])
    (jax_res,) = compare_jax_and_py(x_fg, [])
def test_jax_CAReduce():
    """Check JAX conversion of `CAReduce` `Op`s: `sum`, `prod`, and `all`."""
    a_tt = tt.vector("a")
    a_tt.tag.test_value = np.r_[1, 2, 3].astype(tt.config.floatX)
    x = tt.sum(a_tt, axis=None)
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(tt.config.floatX)])
    a_tt = tt.matrix("a")
    a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)
    x = tt.sum(a_tt, axis=0)
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
    x = tt.sum(a_tt, axis=1)
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
    a_tt = tt.matrix("a")
    a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)
    x = tt.prod(a_tt, axis=0)
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
    x = tt.all(a_tt)
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
def test_jax_MakeVector():
    """Check JAX conversion of the `MakeVector` `Op`."""
    vec = tt.opt.make_vector(1, 2, 3)
    fgraph = theano.gof.FunctionGraph([], [vec])
    compare_jax_and_py(fgraph, [])
def test_jax_Reshape():
    """Check JAX conversion of `Reshape` with a constant target shape."""
    a_tt = tt.vector("a")
    x = tt.basic.reshape(a_tt, (2, 2))
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(
        x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)]
    )
def test_jax_Reshape_omnistaging():
    """Check `Reshape` whose target shape is itself computed from the input.

    This exercises the "omnistaging" breakage in JAX, where traced (abstract)
    shape values cannot be used as concrete reshape arguments.
    """
    # Test breaking "omnistaging" changes in JAX.
    # See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68
    a_tt = tt.vector("a")
    x = tt.basic.reshape(a_tt, (a_tt.shape[0] // 2, a_tt.shape[0] // 3))
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(x_fg, [np.empty((6,)).astype(theano.config.floatX)])
def test_jax_Dimshuffle():
    """Check JAX conversion of `DimShuffle`: transpose, added and dropped axes."""
    a_tt = tt.matrix("a")
    x = a_tt.T
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(
        x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)]
    )
    x = a_tt.dimshuffle([0, 1, "x"])
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(
        x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)]
    )
    # Dropping a broadcastable dimension.
    a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True])
    x = a_tt.dimshuffle((0,))
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)])
    # Same, but through the in-place variant of the `Op`.
    a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True])
    x = tt.elemwise.DimShuffle([False, True], (0,), inplace=True)(a_tt)
    x_fg = theano.gof.FunctionGraph([a_tt], [x])
    _ = compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)])
def test_jax_variadic_Scalar():
    """Check scalar `Op`s that Theano also uses in variadic form (e.g. `mul`)."""
    mu = tt.vector("mu", dtype=tt.config.floatX)
    mu.tag.test_value = np.r_[0.1, 1.1].astype(tt.config.floatX)
    tau = tt.vector("tau", dtype=tt.config.floatX)
    tau.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX)
    res = -tau * mu
    fgraph = theano.gof.FunctionGraph([mu, tau], [res])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
    res = -tau * (tau - mu) ** 2
    fgraph = theano.gof.FunctionGraph([mu, tau], [res])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_logp():
    """Check a PyMC3-style normal log-probability graph (`switch` + reductions)."""
    mu = tt.vector("mu")
    mu.tag.test_value = np.r_[0.0, 0.0].astype(tt.config.floatX)
    tau = tt.vector("tau")
    tau.tag.test_value = np.r_[1.0, 1.0].astype(tt.config.floatX)
    sigma = tt.vector("sigma")
    sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(tt.config.floatX)
    value = tt.vector("value")
    value.tag.test_value = np.r_[0.1, -10].astype(tt.config.floatX)
    logp = (-tau * (value - mu) ** 2 + tt.log(tau / np.pi / 2.0)) / 2.0
    conditions = [sigma > 0]
    alltrue = tt.all([tt.all(1 * val) for val in conditions])
    normal_logp = tt.switch(alltrue, logp, -np.inf)
    fgraph = theano.gof.FunctionGraph([mu, tau, sigma, value], [normal_logp])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_multioutput():
    """Check a `FunctionGraph` with more than one output."""
    x = tt.vector("x")
    x.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX)
    y = tt.vector("y")
    y.tag.test_value = np.r_[3.0, 4.0].astype(tt.config.floatX)
    w = tt.cosh(x ** 2 + y / 3.0)
    v = tt.cosh(x / 3.0 + y ** 2)
    fgraph = theano.gof.FunctionGraph([x, y], [w, v])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_nnet():
    """Check JAX conversion of `nnet` `Op`s: sigmoid variants and softplus."""
    x = tt.vector("x")
    x.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX)
    out = tt.nnet.sigmoid(x)
    fgraph = theano.gof.FunctionGraph([x], [out])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
    out = tt.nnet.ultra_fast_sigmoid(x)
    fgraph = theano.gof.FunctionGraph([x], [out])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
    out = tt.nnet.softplus(x)
    fgraph = theano.gof.FunctionGraph([x], [out])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_tensor_basics():
    """Check basic tensor algebra (`dot`, `maximum`, `max`) under JAX mode."""
    y = tt.vector("y")
    y.tag.test_value = np.r_[1.0, 2.0].astype(theano.config.floatX)
    x = tt.vector("x")
    x.tag.test_value = np.r_[3.0, 4.0].astype(theano.config.floatX)
    A = tt.matrix("A")
    A.tag.test_value = np.empty((2, 2), dtype=theano.config.floatX)
    alpha = tt.scalar("alpha")
    alpha.tag.test_value = np.array(3.0, dtype=theano.config.floatX)
    beta = tt.scalar("beta")
    beta.tag.test_value = np.array(5.0, dtype=theano.config.floatX)
    # This should be converted into a `Gemv` `Op` when the non-JAX compatible
    # optimizations are turned on; however, when using JAX mode, it should
    # leave the expression alone.
    out = y.dot(alpha * A).dot(x) + beta * y
    fgraph = theano.gof.FunctionGraph([y, x, A, alpha, beta], [out])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
    out = tt.maximum(y, x)
    fgraph = theano.gof.FunctionGraph([y, x], [out])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
    out = tt.max(y)
    fgraph = theano.gof.FunctionGraph([y], [out])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_arange():
    """`ARange` with a symbolic bound is expected to fail under JAX tracing."""
    a = tt.scalar("a")
    a.tag.test_value = 10
    out = tt.arange(a)
    fgraph = theano.gof.FunctionGraph([a], [out])
    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
...@@ -7,12 +7,13 @@ import logging ...@@ -7,12 +7,13 @@ import logging
import warnings import warnings
import theano import theano
from theano import gof
import theano.gof.vm import theano.gof.vm
from theano import config
from six import string_types from six import string_types
from theano.compile.function_module import Supervisor
from theano import config, gof
from theano.compile.function_module import Supervisor
from theano.sandbox.jax_linker import JAXLinker
_logger = logging.getLogger("theano.compile.mode") _logger = logging.getLogger("theano.compile.mode")
...@@ -29,6 +30,7 @@ predefined_linkers = { ...@@ -29,6 +30,7 @@ predefined_linkers = {
"cvm": gof.vm.VM_Linker(use_cloop=True), # Use allow_gc Theano flag "cvm": gof.vm.VM_Linker(use_cloop=True), # Use allow_gc Theano flag
"vm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=False), "vm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=False),
"cvm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=True), "cvm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=True),
"jax": JAXLinker(),
} }
...@@ -411,9 +413,15 @@ if theano.config.cxx: ...@@ -411,9 +413,15 @@ if theano.config.cxx:
else: else:
FAST_RUN = Mode("vm", "fast_run") FAST_RUN = Mode("vm", "fast_run")
JAX = Mode(
JAXLinker(), gof.Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
)
predefined_modes = { predefined_modes = {
"FAST_COMPILE": FAST_COMPILE, "FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN, "FAST_RUN": FAST_RUN,
"JAX": JAX,
} }
instantiated_default_mode = None instantiated_default_mode = None
......
...@@ -596,14 +596,16 @@ AddConfigVar( ...@@ -596,14 +596,16 @@ AddConfigVar(
# Also, please be careful not to modify the first item in the enum when adding # Also, please be careful not to modify the first item in the enum when adding
# new modes, since it is the default mode. # new modes, since it is the default mode.
def filter_mode(val): def filter_mode(val):
if val in [ if (
val
in [
"Mode", "Mode",
"DebugMode", "DebugMode",
"FAST_RUN",
"NanGuardMode", "NanGuardMode",
"FAST_COMPILE",
"DEBUG_MODE", "DEBUG_MODE",
]: ]
or val in theano.compile.mode.predefined_modes
):
return val return val
# This can be executed before Theano is completly imported, so # This can be executed before Theano is completly imported, so
# theano.Mode is not always available. # theano.Mode is not always available.
......
from warnings import warn
from collections.abc import Sequence
from theano.gof.link import (
PerformLinker,
map_storage,
gc_helper,
utils,
add_clear_storage,
Container,
streamline,
)
from theano.gof.graph import Constant
class JAXLinker(PerformLinker):
    """A `Linker` that JIT-compiles NumPy-based operations using JAX.

    Attributes
    ----------
    allow_non_jax: bool
        A boolean indicating whether or not an exception is thrown when the
        graph cannot be JAX compiled (e.g. the graph has an unsupported
        operator).  If `allow_non_jax` is `True`, the fallback is currently
        Python compilation.
    """

    allow_non_jax = False

    def create_jax_thunks(self, compute_map, storage_map):
        """Create a thunk for each output of the `Linker`s `FunctionGraph`.

        This differs from the other thunk-making functions in that it only
        produces thunks for the `FunctionGraph` output nodes.

        Parameters
        ----------
        compute_map: dict
            The compute map dictionary.
        storage_map: dict
            The storage map dictionary.

        Returns
        -------
        thunks: list
            The thunks, one per output node.
        output_nodes: list
            The corresponding output nodes.

        """
        import jax

        from theano.sandbox.jaxify import jax_funcify

        output_nodes = [o.owner for o in self.fgraph.outputs]

        # Create a JAX-compilable function from our `FunctionGraph`
        jaxed_fgraph_outputs = jax_funcify(self.fgraph)

        assert len(jaxed_fgraph_outputs) == len(output_nodes)

        # I suppose we can consider `Constant`s to be "static" according to
        # JAX.
        static_argnums = [
            n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
        ]

        thunk_inputs = [storage_map[n] for n in self.fgraph.inputs]

        thunks = []

        for node, jax_funcs in zip(output_nodes, jaxed_fgraph_outputs):

            thunk_outputs = [storage_map[n] for n in node.outputs]

            # JIT-compile the functions
            if len(node.outputs) > 1:
                # BUGFIX: this previously read `node.ouptputs` (typo), which
                # raised an `AttributeError` for any multi-output node.
                assert len(jax_funcs) == len(node.outputs)
                jax_impl_jits = [
                    jax.jit(jax_func, static_argnums) for jax_func in jax_funcs
                ]
            else:
                assert not isinstance(jax_funcs, Sequence)
                jax_impl_jits = [jax.jit(jax_funcs, static_argnums)]

            def thunk(
                node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs
            ):
                # Evaluate the JIT-compiled function(s) on the current input
                # storage values and write the results into the output
                # storage cells, marking them computed.
                outputs = [
                    jax_impl_jit(*[x[0] for x in thunk_inputs])
                    for jax_impl_jit in jax_impl_jits
                ]
                for o_node, o_storage, o_val in zip(
                    node.outputs, thunk_outputs, outputs
                ):
                    compute_map[o_node][0] = True
                    if len(o_storage) > 1:
                        assert len(o_storage) == len(o_val)
                        for i, o_sub_val in enumerate(o_val):
                            o_storage[i] = o_sub_val
                    else:
                        o_storage[0] = o_val
                return outputs

            thunk.inputs = thunk_inputs
            thunk.outputs = thunk_outputs
            thunk.lazy = False

            thunks.append(thunk)

        return thunks, output_nodes

    def make_all(self, input_storage=None, output_storage=None, storage_map=None):
        """Build the evaluation function for this linker's `FunctionGraph`.

        Tries JAX compilation first; if that raises `NotImplementedError` and
        `allow_non_jax` is set, falls back to per-node Python thunks.
        """
        fgraph = self.fgraph
        nodes = self.schedule(fgraph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            fgraph, nodes, input_storage, output_storage, storage_map
        )

        compute_map = {}
        for k in storage_map:
            compute_map[k] = [k.owner is None]

        try:
            # We need to create thunk functions that will populate the output
            # storage arrays with the JAX-computed values.
            thunks, nodes = self.create_jax_thunks(compute_map, storage_map)
        except NotImplementedError as e:
            if not self.allow_non_jax:
                raise

            warn("JaxLinker could not JAXify graph: {}".format(e))

            # Fallback: one Python thunk per node, like `PerformLinker`.
            thunks = []
            for node in nodes:
                thunk = node.op.make_thunk(
                    node, storage_map, compute_map, no_recycling, "py"
                )
                thunk_inputs = [storage_map[v] for v in node.inputs]
                thunk_outputs = [storage_map[v] for v in node.outputs]

                thunk.inputs = thunk_inputs
                thunk.outputs = thunk_outputs

                thunks.append(thunk)

        computed, last_user = gc_helper(nodes)

        if self.allow_gc:
            # For each node, the input storages that can be cleared after the
            # node runs (last use, not a graph output).
            post_thunk_old_storage = []

            for node in nodes:
                post_thunk_old_storage.append(
                    [
                        storage_map[input]
                        for input in node.inputs
                        if (input in computed)
                        and (input not in fgraph.outputs)
                        and (node == last_user[input])
                    ]
                )
        else:
            post_thunk_old_storage = None

        if no_recycling is True:
            no_recycling = list(storage_map.values())
            no_recycling = utils.difference(no_recycling, input_storage)
        else:
            no_recycling = [
                storage_map[r] for r in no_recycling if r not in fgraph.inputs
            ]

        fn = streamline(
            fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
        )

        fn.allow_gc = self.allow_gc
        add_clear_storage(fn, computed, storage_map)
        fn.storage_map = storage_map

        return (
            fn,
            [
                Container(input, storage)
                for input, storage in zip(fgraph.inputs, input_storage)
            ],
            [
                Container(output, storage, True)
                for output, storage in zip(fgraph.outputs, output_storage)
            ],
            thunks,
            nodes,
        )
import theano
import jax
import jax.numpy as jnp
from warnings import warn
from functools import partial, update_wrapper, reduce
from collections.abc import Sequence
from functools import singledispatch as dispatch
from theano.gof import FunctionGraph
from theano.ifelse import IfElse
from theano.tensor.subtensor import (
get_idx_list,
Subtensor,
IncSubtensor,
# This is essentially `np.take`
AdvancedSubtensor1,
AdvancedIncSubtensor1,
# Boolean mask indexing and setting
BaseAdvancedSubtensor,
BaseAdvancedIncSubtensor,
)
from theano.scan_module.scan_op import Scan
from theano.scan_module.scan_utils import scan_args as ScanArgs
from theano.tensor.basic import (
Dot,
ARange,
TensorFromScalar,
ScalarFromTensor,
AllocEmpty,
Alloc,
Reshape,
Join,
)
from theano.scalar.basic import (
ScalarOp,
Composite,
Cast,
Clip,
)
from theano.tensor.elemwise import Elemwise, CAReduce, DimShuffle
from theano.compile.ops import (
DeepCopyOp,
Shape,
Shape_i,
SpecifyShape,
Rebroadcast,
ViewOp,
)
from theano.tensor.opt import MakeVector
from theano.tensor.nnet.sigm import ScalarSoftplus
# XXX: Enabling this will break some shape-based functionality, and severely
# limit the types of graphs that can be converted.
# See https://github.com/google/jax/blob/4d556837cc9003492f674c012689efc3d68fdf5f/design_notes/omnistaging.md
jax.config.disable_omnistaging()
# Match Theano's default float64 precision (JAX defaults to float32).
jax.config.update("jax_enable_x64", True)
# Groups of subtensor `Op`s that share a single JAX conversion each.
subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor)
def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
    """Compose JAX implementations of node operations.

    Parameters
    ----------
    out_node: Node
        The output node.
    fgraph_inputs: List[Variable]
        The inputs--in a `FunctionGraph` sense--to `out_node`.
    memo: Mapping (Optional)
        A map from visited nodes to their JAX functions.

    Outputs
    -------
    A `function` object that represents the composed JAX operations and takes
    the same form of inputs as `fgraph_inputs`.

    """
    if memo is None:
        memo = {}
    if out_node in memo:
        return memo[out_node]
    jax_return_func = jax_funcify(out_node.op)
    # Build, for each of `out_node`'s inputs, a function mapping the
    # graph-level inputs to that input's value.
    input_funcs = []
    for i in out_node.inputs:
        if i in fgraph_inputs:
            idx = fgraph_inputs.index(i)
            # `i_dtype`/`idx` are bound as defaults to avoid Python's
            # late-binding closure pitfall inside this loop.
            def jax_inputs_func(*inputs, i_dtype=i.dtype, idx=idx):
                return jnp.array(inputs[idx], dtype=jnp.dtype(i_dtype))
            input_f = jax_inputs_func
        elif i.owner is None:
            # A constant/shared input: embed its data directly.
            def jax_data_func(*inputs, i_dtype=i.dtype, i_data=i.data):
                return jnp.array(i_data, dtype=jnp.dtype(i_dtype))
            input_f = jax_data_func
        else:
            # Recurse into the subgraph that produces this input.
            input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
        input_funcs.append(input_f)
    if not isinstance(jax_return_func, Sequence):
        jax_return_func = [jax_return_func]
    jax_funcs = []
    for return_func in jax_return_func:
        # NOTE(review): `return_func` is a free variable here; with
        # multi-output nodes every `jax_func` closes over the loop variable.
        # It works because each `jax_func` is built in its own iteration and
        # appended immediately, but binding it as a default would be safer.
        def jax_func(*inputs):
            func_args = [fn(*inputs) for fn in input_funcs]
            return return_func(*func_args)
        jax_funcs.append(update_wrapper(jax_func, return_func))
    if len(out_node.outputs) == 1:
        jax_funcs = jax_funcs[0]
    memo[out_node] = jax_funcs
    return jax_funcs
@dispatch
def jax_funcify(op):
    """Create a JAX "perform" function for a Theano `Variable` and its `Op`.

    This is a `singledispatch` base; concrete `Op` types register their
    conversions below.  Unregistered `Op`s are unsupported.
    """
    raise NotImplementedError("No JAX conversion for the given `Op`: {}".format(op))
@jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op):
    """Map a scalar `Op` to its `jax.numpy` equivalent via `nfunc_spec`."""
    # `nfunc_spec[0]` is the NumPy function name, possibly dotted
    # (e.g. "linalg.det").
    func_name = op.nfunc_spec[0]
    if "." in func_name:
        jnp_func = reduce(getattr, [jax] + func_name.split("."))
    else:
        jnp_func = getattr(jnp, func_name)
    if hasattr(op, "nfunc_variadic"):
        # These are special cases that handle invalid arities due to the broken
        # Theano `Op` type contract (e.g. binary `Op`s that also function as
        # their own variadic counterparts--even when those counterparts already
        # exist as independent `Op`s).
        jax_variadic_func = getattr(jnp, op.nfunc_variadic)
        def elemwise(*args):
            # `nfunc_spec[1]` is the expected arity; more arguments means the
            # variadic reduction form was intended.
            if len(args) > op.nfunc_spec[1]:
                return jax_variadic_func(
                    jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
                )
            else:
                return jnp_func(*args)
        return elemwise
    else:
        return jnp_func
@jax_funcify.register(Clip)
def jax_funcify_Clip(op):
    """Return a JAX implementation of the `Clip` scalar `Op`.

    Uses `jnp.clip`, which is equivalent to the previous hand-rolled nested
    `jnp.where` chain for a valid `min <= max` bound pair, and avoids
    shadowing the `min`/`max` builtins.
    """

    def clip(x, min_value, max_value):
        # Two-sided bound: values below `min_value` are raised to it, values
        # above `max_value` are lowered to it.
        return jnp.clip(x, min_value, max_value)

    return clip
@jax_funcify.register(ScalarSoftplus)
def jax_funcify_ScalarSoftplus(op):
    """Return a JAX implementation of Theano's stabilized scalar softplus."""

    def scalarsoftplus(x):
        # Same piecewise evaluation as the Theano `Op`: 0 for very negative
        # inputs, the identity for very positive ones, log1p(exp(x)) between.
        stable = jnp.log1p(jnp.exp(x))
        upper = jnp.where(x > 30.0, x, stable)
        return jnp.where(x < -30.0, 0.0, upper)

    return scalarsoftplus
@jax_funcify.register(AllocEmpty)
def jax_funcify_AllocEmpty(op):
    """Return a JAX implementation of `AllocEmpty` (allocate without init)."""
    out_dtype = op.dtype

    def allocempty(*shape):
        return jnp.empty(shape, dtype=out_dtype)

    return allocempty
@jax_funcify.register(Alloc)
def jax_funcify_Alloc(op):
    """Return a JAX implementation of `Alloc` (broadcast a value to a shape)."""

    def alloc(x, *shape):
        return jnp.broadcast_to(x, shape)

    return alloc
@jax_funcify.register(Dot)
def jax_funcify_Dot(op):
    """Return a JAX implementation of the `Dot` `Op` (matrix/vector product)."""

    def dot(x, y):
        # Defer directly to JAX's NumPy-compatible `dot`.
        return jnp.dot(x, y)

    return dot
@jax_funcify.register(ARange)
def jax_funcify_ARange(op):
    """Return a JAX implementation of the `ARange` `Op`.

    XXX: This currently requires concrete arguments — `jnp.arange` cannot
    take traced (abstract) values for its bounds.
    """
    # XXX: This currently requires concrete arguments.
    def arange(start, stop, step):
        return jnp.arange(start, stop, step, dtype=op.dtype)
    return arange
def jnp_safe_copy(x):
    """Copy `x` with `jnp.copy`, falling back to the object's own ``copy``.

    If neither works, a warning is emitted and `x` is returned as-is.
    """
    try:
        return jnp.copy(x)
    except NotImplementedError:
        warn("`jnp.copy` is not implemented yet. Using the object's `copy` method.")
        if not hasattr(x, "copy"):
            warn("Object has no `copy` method: {}".format(x))
            return x
        return jnp.array(x.copy())
@jax_funcify.register(DeepCopyOp)
def jax_funcify_DeepCopyOp(op):
    """Return a JAX implementation of `DeepCopyOp` using `jnp_safe_copy`."""
    def deepcopyop(x):
        return jnp_safe_copy(x)
    return deepcopyop
@jax_funcify.register(Shape)
def jax_funcify_Shape(op):
    """Return a JAX implementation of the `Shape` `Op`."""
    def shape(x):
        return jnp.shape(x)
    return shape
@jax_funcify.register(Shape_i)
def jax_funcify_Shape_i(op):
    """Return a JAX implementation of `Shape_i` (a single shape entry)."""
    dim = op.i

    def shape_i(x):
        return jnp.shape(x)[dim]

    return shape_i
@jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op):
    """Return a JAX implementation of `SpecifyShape` (runtime shape check)."""
    def specifyshape(x, shape):
        # NOTE(review): `assert` is stripped under `python -O`, so these
        # checks silently disappear in optimized runs — the test suite relies
        # on the `AssertionError` being raised.
        assert x.ndim == shape.size
        assert jnp.all(x.shape == shape), ("got shape", x.shape, "expected", shape)
        return x
    return specifyshape
@jax_funcify.register(Rebroadcast)
def jax_funcify_Rebroadcast(op):
    """Return a JAX implementation of the `Rebroadcast` `Op`.

    `Rebroadcast` only changes broadcast metadata in Theano; here it is a
    runtime check that any axis marked broadcastable actually has length 1.
    """
    # `op.axis` maps axis index -> broadcastable flag.
    op_axis = op.axis
    def rebroadcast(x):
        for axis, value in op_axis.items():
            if value and x.shape[axis] != 1:
                raise ValueError(
                    "Dimension %s in Rebroadcast's input was"
                    " supposed to be 1 (got %s instead)" % (axis, x.shape[axis])
                )
        return x
    return rebroadcast
@jax_funcify.register(ViewOp)
def jax_funcify_ViewOp(op):
    """Return a JAX implementation of `ViewOp` (identity; views are free in JAX)."""
    def viewop(x):
        return x
    return viewop
@jax_funcify.register(Cast)
def jax_funcify_Cast(op):
    """Return a JAX implementation of the `Cast` scalar `Op` (dtype change)."""
    target_dtype = op.o_type.dtype

    def cast(x):
        return jnp.array(x).astype(target_dtype)

    return cast
@jax_funcify.register(TensorFromScalar)
def jax_funcify_TensorFromScalar(op):
    """Return a JAX implementation of `TensorFromScalar`."""

    def tensor_from_scalar(x):
        # JAX represents scalars as 0-d arrays, so wrapping suffices.
        return jnp.array(x)

    return tensor_from_scalar
@jax_funcify.register(ScalarFromTensor)
def jax_funcify_ScalarFromTensor(op):
    """Return a JAX implementation of `ScalarFromTensor`."""

    def scalar_from_tensor(x):
        # Flatten and take the first element, matching the Theano semantics
        # for a (possibly 0-d) single-element tensor.
        flat = jnp.array(x).flatten()
        return flat[0]

    return scalar_from_tensor
@jax_funcify.register(Elemwise)
def jax_funcify_Elemwise(op):
    """An `Elemwise` lowers to the JAX function of its wrapped scalar `Op`;
    broadcasting is handled by `jax.numpy` itself."""
    return jax_funcify(op.scalar_op)
@jax_funcify.register(Composite)
def jax_funcify_Composite(op):
    """A `Composite` scalar `Op` lowers via its inner `FunctionGraph`."""
    return jax_funcify(op.fgraph)
@jax_funcify.register(Scan)
def jax_funcify_Scan(op):
    """Return a (work-in-progress) JAX implementation of the `Scan` `Op`.

    The argument-marshalling helpers below both raise `NotImplementedError`,
    so calling the returned `scan` currently always fails; only the overall
    structure of the `jax.lax.scan` translation is sketched out.
    """
    inner_fg = FunctionGraph(op.inputs, op.outputs)
    jax_tt_inner_func = jax_funcify(inner_fg)
    def scan(*outer_inputs):
        scan_args = ScanArgs(
            outer_inputs, [None] * op.n_outs, op.inputs, op.outputs, op.info
        )
        # `outer_inputs` is a list with the following composite form:
        # [n_steps]
        # + outer_in_seqs
        # + outer_in_mit_mot
        # + outer_in_mit_sot
        # + outer_in_sit_sot
        # + outer_in_shared
        # + outer_in_nit_sot
        # + outer_in_non_seqs
        n_steps = scan_args.n_steps
        seqs = scan_args.outer_in_seqs
        n_non_seqs = len(scan_args.outer_in_non_seqs)
        # TODO: sit_sots
        # Build the initial tap windows for the mit-sot outputs.
        mit_sot_in_slices = []
        for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
            neg_taps = [abs(t) for t in tap if t < 0]
            pos_taps = [abs(t) for t in tap if t > 0]
            max_neg = max(neg_taps) if neg_taps else 0
            max_pos = max(pos_taps) if pos_taps else 0
            init_slice = seq[: max_neg + max_pos]
            mit_sot_in_slices.append(init_slice)
        init_carry = [mit_sot_in_slices, scan_args.outer_in_non_seqs]
        def jax_args_to_inner_scan(op, carry, x):
            # `carry` contains all inner-output taps, non_seqs, and shared
            # terms
            (
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_shared,
                inner_in_non_seqs,
            ) = carry
            # `x` contains the in_seqs
            inner_in_seqs = x
            # `inner_scan_inputs` is a list with the following composite form:
            # inner_in_seqs
            # + sum(inner_in_mit_mot, [])
            # + sum(inner_in_mit_sot, [])
            # + inner_in_sit_sot
            # + inner_in_shared
            # + inner_in_non_seqs
            inner_scan_inputs = [
                inner_in_seqs,
                inner_in_mit_mot,
                inner_in_mit_sot,
                inner_in_sit_sot,
                inner_in_non_seqs,
            ]
            raise NotImplementedError()
            return inner_scan_inputs
        def inner_scan_outs_to_jax_outs(
            op,
            old_carry,
            inner_scan_outs,
        ):
            # `inner_scan_outs` is a list with the following
            # composite form:
            # outer_out_mit_mot
            # + outer_out_mit_sot
            # + outer_out_sit_sot
            # + outer_out_nit_sot
            # + outer_out_shared
            # + cond
            (
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_nit_sot,
                outer_out_shared,
                cond,
            ) = inner_scan_outs
            # NOTE(review): to pick the *trailing* non-seq entries this should
            # presumably be `old_carry[-n_non_seqs:]`; `[:-n_non_seqs]` takes
            # everything *except* them.  Unreachable for now (see the raise
            # above), but confirm before finishing the implementation.
            outer_out_non_seqs = old_carry[:-n_non_seqs]
            # This should contain all inner-output taps, non_seqs, and shared
            # terms
            carry = [
                outer_out_mit_mot,
                outer_out_mit_sot,
                outer_out_sit_sot,
                outer_out_shared,
                outer_out_non_seqs,
            ]
            # This should contain all inner-outputs that produce
            # outer-outputs
            y = []
            raise NotImplementedError()
            return (carry, y)
        def jax_inner_func(carry, x):
            inner_args = jax_args_to_inner_scan(op, carry, x)
            inner_scan_outs = jax_tt_inner_func(*inner_args)
            new_carry, y = inner_scan_outs_to_jax_outs(op, inner_scan_outs)
            return new_carry, y
        return jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
    return scan
@jax_funcify.register(IfElse)
def jax_funcify_IfElse(op):
    """Return a JAX implementation of `IfElse`."""

    def ifelse(cond, *args):
        # The first `op.n_outs` args belong to the true branch, the
        # remainder to the false branch.
        return args[: op.n_outs] if cond else args[op.n_outs :]

    return ifelse
def convert_indices(indices, entry):
    """Map one entry of an op's static index template to a concrete index.

    `entry` values that are Theano types are placeholders for runtime
    indices and consume the next value from `indices` (mutating it);
    slices are rebuilt by converting their start/stop/step recursively;
    anything else is a literal and passes through unchanged.
    """
    if isinstance(entry, slice):
        return slice(
            convert_indices(indices, entry.start),
            convert_indices(indices, entry.stop),
            convert_indices(indices, entry.step),
        )
    if indices and isinstance(entry, theano.gof.Type):
        return indices.pop(0)
    return entry
@jax_funcify.register(Subtensor)
def jax_funcify_Subtensor(op):
    """Return a JAX implementation of `Subtensor`-style indexing."""
    idx_list = getattr(op, "idx_list", None)

    def subtensor(x, *ilists):
        # Rebuild the concrete indices from the op's static index template
        # (when present) and the runtime index inputs.
        if idx_list:
            indices = get_idx_list((x,) + ilists, idx_list)
        else:
            indices = ilists
        if len(indices) == 1:
            indices = indices[0]
        return x[indices]

    return subtensor
# Register the basic `subtensor` converter for every subtensor op variant
# collected in `subtensor_ops`.
_ = [jax_funcify.register(op, jax_funcify_Subtensor) for op in subtensor_ops]
def jax_funcify_IncSubtensor(op):
    """Return a JAX implementation of Theano's ``IncSubtensor``-style ops.

    The returned callable either overwrites (when ``op.set_instead_of_inc``
    is true) or accumulates into the indexed region of ``x`` with ``y``,
    returning a new array.  The concrete indices are rebuilt from the op's
    static ``idx_list`` template plus the runtime index inputs ``ilist``.
    """
    # BUG FIX: `jax.ops.index_update`/`index_add` were deprecated and later
    # removed from JAX; the supported equivalent is the functional
    # `x.at[idx]` indexed-update syntax.
    if getattr(op, "set_instead_of_inc", False):

        def jax_fn(x, indices, y):
            # Functional update: copy of `x` with the region set to `y`.
            return x.at[indices].set(y)

    else:

        def jax_fn(x, indices, y):
            # Functional update: copy of `x` with `y` added into the region.
            return x.at[indices].add(y)

    def incsubtensor(x, y, *ilist, jax_fn=jax_fn):
        _ilist = list(ilist)
        # Substitute the runtime indices into the op's index template.
        cdata = tuple(convert_indices(_ilist, idx) for idx in op.idx_list)
        if len(cdata) == 1:
            cdata = cdata[0]
        return jax_fn(x, cdata, y)

    return incsubtensor
# Register the `incsubtensor` converter for every inc/set-subtensor op
# variant collected in `incsubtensor_ops`.
_ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_ops]
@jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph(fgraph):
    """Convert a `FunctionGraph` into one composed JAX callable per output.

    Outputs with no owner (e.g. graph inputs or constants passed straight
    through) are skipped, matching the original behavior.
    """
    return [
        compose_jax_funcs(out.owner, fgraph.inputs)
        for out in fgraph.outputs
        if out.owner is not None
    ]
@jax_funcify.register(CAReduce)
def jax_funcify_CAReduce(op):
    """Return a JAX implementation of a Theano `CAReduce` reduction.

    The returned callable reduces its input over ``op.axis`` (all axes when
    ``op.axis`` is ``None``), preferring the NumPy-equivalent named by the
    op's ``nfunc_spec`` and falling back to ``jax.lax.reduce`` with the
    scalar op's identity element.
    """
    axis = op.axis
    op_nfunc_spec = getattr(op, "nfunc_spec", None)
    scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
    scalar_op_name = getattr(op.scalar_op, "name", None)
    scalar_op_identity = getattr(op.scalar_op, "identity", None)
    acc_dtype = getattr(op, "acc_dtype", None)

    def careduce(x):
        # BUG FIX: resolve defaults into locals instead of rebinding the
        # closure variables -- the original used `nonlocal`, so the first
        # call's `axis`/`acc_dtype` leaked into every subsequent call.
        reduce_axes = list(range(x.ndim)) if axis is None else list(axis)
        out_dtype = x.dtype.type if acc_dtype is None else acc_dtype

        if op_nfunc_spec:
            # The op names its NumPy equivalent (e.g. "sum"); use the
            # matching `jax.numpy` function directly.
            jax_op = getattr(jnp, op_nfunc_spec[0])
            return jax_op(x, axis=reduce_axes).astype(out_dtype)

        # The Theano `Op` didn't tell us which NumPy equivalent to use (or
        # there isn't one), so we use this fallback approach
        if scalar_nfunc_spec:
            scalar_fn_name = scalar_nfunc_spec[0]
        elif scalar_op_name:
            scalar_fn_name = scalar_op_name
        else:
            # BUG FIX: the original fell through with `scalar_fn_name`
            # unbound, producing an opaque NameError.
            raise NotImplementedError(
                "Could not determine a JAX reduction for {}".format(op)
            )

        # BUG FIX: the original bound `reversed(sorted(axis))` -- an
        # iterator, which is *always* truthy -- so the emptiness check
        # below never took the `else` branch.  Sort into a list instead.
        to_reduce = sorted(reduce_axes, reverse=True)
        if to_reduce:
            # In this case, we need to use the `jax.lax` function (if there
            # is one), and not the `jnp` version.
            jax_op = getattr(jax.lax, scalar_fn_name)
            init_value = jnp.array(scalar_op_identity, dtype=out_dtype)
            return jax.lax.reduce(x, init_value, jax_op, to_reduce).astype(out_dtype)
        else:
            return x

    return careduce
@jax_funcify.register(MakeVector)
def jax_funcify_MakeVector(op):
    """Return a JAX implementation of `MakeVector`."""

    def makevector(*vals):
        # Pack the scalar inputs into a 1-D array of the op's dtype.
        return jnp.array(vals, dtype=op.dtype)

    return makevector
@jax_funcify.register(Reshape)
def jax_funcify_Reshape(op):
    """Return a JAX implementation of `Reshape`."""

    def reshape(tensor, target_shape):
        return jnp.reshape(tensor, target_shape)

    return reshape
@jax_funcify.register(DimShuffle)
def jax_funcify_DimShuffle(op):
    """Return a JAX implementation of `DimShuffle`.

    Permutes/drops dimensions first, then inserts the broadcastable
    ("augmented") length-1 dimensions, copying the result unless the op
    is marked in-place.
    """

    def dimshuffle(x):
        transposed = jnp.transpose(x, op.shuffle + op.drop)
        # Keep only the shuffled (non-dropped) dims, then insert a
        # length-1 dim at each augmented position.
        out_shape = list(transposed.shape[: len(op.shuffle)])
        for augmented_dim in op.augment:
            out_shape.insert(augmented_dim, 1)
        reshaped = jnp.reshape(transposed, out_shape)
        if op.inplace:
            return reshaped
        return jnp_safe_copy(reshaped)

    return dimshuffle
@jax_funcify.register(Join)
def jax_funcify_Join(op):
    """Return a JAX implementation of `Join` (concatenation)."""

    def join(axis, *tensors):
        view = op.view
        # When a "view" input is designated and every other input is empty
        # along the join axis, the result is just the view input itself.
        others = tensors[0:view] + tensors[view + 1 :]
        if view != -1 and all(t.shape[axis] == 0 for t in others):
            return tensors[view]
        ndim = tensors[0].ndim
        if axis < -ndim:
            raise IndexError("Join axis %d out of bounds [0, %d)" % (axis, ndim))
        return jnp.concatenate(tensors, axis=axis)

    return join
...@@ -1767,6 +1767,7 @@ class Maximum(BinaryScalarOp): ...@@ -1767,6 +1767,7 @@ class Maximum(BinaryScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ("maximum", 2, 1) nfunc_spec = ("maximum", 2, 1)
nfunc_variadic = "maximum"
def impl(self, *inputs): def impl(self, *inputs):
# The built-in max function don't support complex type # The built-in max function don't support complex type
...@@ -1811,6 +1812,7 @@ class Minimum(BinaryScalarOp): ...@@ -1811,6 +1812,7 @@ class Minimum(BinaryScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ("minimum", 2, 1) nfunc_spec = ("minimum", 2, 1)
nfunc_variadic = "minimum"
def impl(self, *inputs): def impl(self, *inputs):
# The built-in min function don't support complex type # The built-in min function don't support complex type
...@@ -1855,6 +1857,7 @@ class Add(ScalarOp): ...@@ -1855,6 +1857,7 @@ class Add(ScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ("add", 2, 1) nfunc_spec = ("add", 2, 1)
nfunc_variadic = "sum"
def impl(self, *inputs): def impl(self, *inputs):
return sum(inputs) return sum(inputs)
...@@ -1896,6 +1899,7 @@ class Mul(ScalarOp): ...@@ -1896,6 +1899,7 @@ class Mul(ScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ("multiply", 2, 1) nfunc_spec = ("multiply", 2, 1)
nfunc_variadic = "product"
def impl(self, *inputs): def impl(self, *inputs):
return np.product(inputs) return np.product(inputs)
...@@ -2984,6 +2988,8 @@ class Inv(UnaryScalarOp): ...@@ -2984,6 +2988,8 @@ class Inv(UnaryScalarOp):
""" """
nfunc_spec = ("reciprocal", 1, 1)
def impl(self, x): def impl(self, x):
return np.float32(1.0) / x return np.float32(1.0) / x
......
...@@ -1787,6 +1787,20 @@ def max_and_argmax(a, axis=None, keepdims=False): ...@@ -1787,6 +1787,20 @@ def max_and_argmax(a, axis=None, keepdims=False):
return [out, argout] return [out, argout]
class Max(CAReduce):
nfunc_spec = ("max", 1, 1)
def __init__(self, axis):
super().__init__(scal.maximum, axis)
class Min(CAReduce):
nfunc_spec = ("min", 1, 1)
def __init__(self, axis):
super().__init__(scal.minimum, axis)
@constructor @constructor
def max(x, axis=None, keepdims=False): def max(x, axis=None, keepdims=False):
""" """
...@@ -1823,7 +1837,7 @@ def max(x, axis=None, keepdims=False): ...@@ -1823,7 +1837,7 @@ def max(x, axis=None, keepdims=False):
try: try:
out = max_and_argmax(x, axis)[0] out = max_and_argmax(x, axis)[0]
except Exception: except Exception:
out = CAReduce(scal.maximum, axis)(x) out = Max(axis)(x)
if keepdims: if keepdims:
out = makeKeepDims(x, out, axis) out = makeKeepDims(x, out, axis)
...@@ -3416,7 +3430,7 @@ def prod( ...@@ -3416,7 +3430,7 @@ def prod(
class Mean(elemwise.CAReduce): class Mean(elemwise.CAReduce):
def __init__(self, axis=None): def __init__(self, axis=None):
elemwise.CAReduce.__init__(self, scal.add, axis) super().__init__(scal.add, axis)
assert self.axis is None or len(self.axis) == 1 assert self.axis is None or len(self.axis) == 1
def __str__(self): def __str__(self):
...@@ -3443,7 +3457,7 @@ class Mean(elemwise.CAReduce): ...@@ -3443,7 +3457,7 @@ class Mean(elemwise.CAReduce):
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
if self.axis is not None: if self.axis is not None:
return super(Op, self).c_code(node, name, inames, onames, sub) return super(Op, self).c_code(node, name, inames, onames, sub)
ret = elemwise.CAReduce.c_code(self, node, name, inames, onames, sub) ret = super().c_code(self, node, name, inames, onames, sub)
# TODO: c_code perform support only axis is None # TODO: c_code perform support only axis is None
return ( return (
ret ret
......
...@@ -1761,6 +1761,7 @@ class All(CAReduce): ...@@ -1761,6 +1761,7 @@ class All(CAReduce):
""" """
__props__ = ("axis",) __props__ = ("axis",)
nfunc_spec = ("all", 1, 1)
def __init__(self, axis=None): def __init__(self, axis=None):
CAReduce.__init__(self, scalar.and_, axis) CAReduce.__init__(self, scalar.and_, axis)
...@@ -1793,6 +1794,7 @@ class Any(CAReduce): ...@@ -1793,6 +1794,7 @@ class Any(CAReduce):
""" """
__props__ = ("axis",) __props__ = ("axis",)
nfunc_spec = ("any", 1, 1)
def __init__(self, axis=None): def __init__(self, axis=None):
CAReduce.__init__(self, scalar.or_, axis) CAReduce.__init__(self, scalar.or_, axis)
...@@ -2027,6 +2029,7 @@ class Sum(CAReduceDtype): ...@@ -2027,6 +2029,7 @@ class Sum(CAReduceDtype):
""" """
__props__ = ("axis", "dtype", "acc_dtype") __props__ = ("axis", "dtype", "acc_dtype")
nfunc_spec = ("sum", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None): def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__( CAReduceDtype.__init__(
...@@ -2085,6 +2088,7 @@ class Prod(CAReduceDtype): ...@@ -2085,6 +2088,7 @@ class Prod(CAReduceDtype):
""" """
__props__ = ("axis", "dtype", "acc_dtype") __props__ = ("axis", "dtype", "acc_dtype")
nfunc_spec = ("sum", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False): def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False):
CAReduceDtype.__init__( CAReduceDtype.__init__(
......
...@@ -31,6 +31,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp): ...@@ -31,6 +31,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp):
""" """
nfunc_spec = ("scipy.special.expit", 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
if x < -30.0: if x < -30.0:
...@@ -196,6 +198,8 @@ class UltraFastScalarSigmoid(scalar.UnaryScalarOp): ...@@ -196,6 +198,8 @@ class UltraFastScalarSigmoid(scalar.UnaryScalarOp):
""" """
nfunc_spec = ("scipy.special.expit", 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
x = 0.5 * x x = 0.5 * x
......
...@@ -31,44 +31,40 @@ supposed to be canonical. ...@@ -31,44 +31,40 @@ supposed to be canonical.
""" """
# TODO: intelligent merge for mul/add
# TODO: 0*x -> 0
import logging import logging
from theano import gof import theano.tensor.basic as tt
from theano.tensor.elemwise import CAReduce import theano.scalar.basic as scal
from theano.tensor import basic as T
from theano.tensor import DimShuffle, Subtensor
from theano.gof.opt import copy_stack_trace, local_optimizer
from theano.tensor.subtensor import Subtensor
from theano.tensor.elemwise import CAReduce, DimShuffle
from theano.tensor.opt import register_uncanonicalize from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal
from theano.gof.opt import copy_stack_trace
_logger = logging.getLogger("theano.tensor.opt") _logger = logging.getLogger("theano.tensor.opt")
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.MaxAndArgmax]) @local_optimizer([tt.MaxAndArgmax])
def local_max_and_argmax(node): def local_max_and_argmax(node):
""" """
If we don't use the argmax, change it to a max only. If we don't use the argmax, change it to a max only.
""" """
if isinstance(node.op, T.MaxAndArgmax): if isinstance(node.op, tt.MaxAndArgmax):
axis = node.op.get_params(node) axis = node.op.get_params(node)
if len(node.outputs[1].clients) == 0: if len(node.outputs[1].clients) == 0:
new = CAReduce(scal.maximum, axis)(node.inputs[0]) new = tt.Max(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new) copy_stack_trace(node.outputs[0], new)
return [new, None] return [new, None]
if len(node.outputs[0].clients) == 0: if len(node.outputs[0].clients) == 0:
new = T.Argmax(axis)(node.inputs[0]) new = tt.Argmax(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new) copy_stack_trace(node.outputs[0], new)
return [None, new] return [None, new]
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.neg]) @local_optimizer([tt.neg])
def local_max_to_min(node): def local_max_to_min(node):
""" """
Change -(max(-x)) to min. Change -(max(-x)) to min.
...@@ -81,7 +77,7 @@ def local_max_to_min(node): ...@@ -81,7 +77,7 @@ def local_max_to_min(node):
the interface put only MaxAndArgmax into the graph. the interface put only MaxAndArgmax into the graph.
""" """
if node.op == T.neg and node.inputs[0].owner: if node.op == tt.neg and node.inputs[0].owner:
max = node.inputs[0] max = node.inputs[0]
if ( if (
max.owner max.owner
...@@ -89,15 +85,15 @@ def local_max_to_min(node): ...@@ -89,15 +85,15 @@ def local_max_to_min(node):
and max.owner.op.scalar_op == scal.maximum and max.owner.op.scalar_op == scal.maximum
): ):
neg = max.owner.inputs[0] neg = max.owner.inputs[0]
if neg.owner and neg.owner.op == T.neg: if neg.owner and neg.owner.op == tt.neg:
new = CAReduce(scal.minimum, max.owner.op.axis)(neg.owner.inputs[0]) new = tt.Min(max.owner.op.axis)(neg.owner.inputs[0])
return [copy_stack_trace(node.outputs[0], new)] return [copy_stack_trace(node.outputs[0], new)]
return False return False
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.Alloc]) @local_optimizer([tt.Alloc])
def local_alloc_dimshuffle(node): def local_alloc_dimshuffle(node):
""" """
If a dimshuffle is inside an alloc and only adds dimension to the If a dimshuffle is inside an alloc and only adds dimension to the
...@@ -105,7 +101,7 @@ def local_alloc_dimshuffle(node): ...@@ -105,7 +101,7 @@ def local_alloc_dimshuffle(node):
Alloc(DimShuffle(x), ...) - > Alloc(x, ...) Alloc(DimShuffle(x), ...) - > Alloc(x, ...)
""" """
if isinstance(node.op, T.Alloc): if isinstance(node.op, tt.Alloc):
input_ = node.inputs[0] input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle): if input_.owner and isinstance(input_.owner.op, DimShuffle):
# check if it only adds dimension to the left # check if it only adds dimension to the left
...@@ -115,12 +111,12 @@ def local_alloc_dimshuffle(node): ...@@ -115,12 +111,12 @@ def local_alloc_dimshuffle(node):
) + tuple(range(input_.owner.inputs[0].ndim)) ) + tuple(range(input_.owner.inputs[0].ndim))
if new_order != expected_new_order: if new_order != expected_new_order:
return False return False
return [T.alloc(input_.owner.inputs[0], *node.inputs[1:])] return [tt.alloc(input_.owner.inputs[0], *node.inputs[1:])]
return False return False
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.Reshape]) @local_optimizer([tt.Reshape])
def local_reshape_dimshuffle(node): def local_reshape_dimshuffle(node):
""" """
If a dimshuffle is inside a reshape and does not change the order If a dimshuffle is inside a reshape and does not change the order
...@@ -128,7 +124,7 @@ def local_reshape_dimshuffle(node): ...@@ -128,7 +124,7 @@ def local_reshape_dimshuffle(node):
Reshape(Dimshuffle(x), shp) -> Reshape(x, shp) Reshape(Dimshuffle(x), shp) -> Reshape(x, shp)
""" """
if isinstance(node.op, T.Reshape): if isinstance(node.op, tt.Reshape):
input_ = node.inputs[0] input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle): if input_.owner and isinstance(input_.owner.op, DimShuffle):
new_order = input_.owner.op.new_order new_order = input_.owner.op.new_order
...@@ -141,7 +137,7 @@ def local_reshape_dimshuffle(node): ...@@ -141,7 +137,7 @@ def local_reshape_dimshuffle(node):
else: else:
offset += 1 offset += 1
return [ return [
T.reshape( tt.reshape(
input_.owner.inputs[0], node.inputs[1], ndim=node.outputs[0].ndim input_.owner.inputs[0], node.inputs[1], ndim=node.outputs[0].ndim
) )
] ]
...@@ -149,7 +145,7 @@ def local_reshape_dimshuffle(node): ...@@ -149,7 +145,7 @@ def local_reshape_dimshuffle(node):
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([DimShuffle]) @local_optimizer([DimShuffle])
def local_dimshuffle_alloc(node): def local_dimshuffle_alloc(node):
""" """
If an alloc is inside a dimshuffle which only adds dimension to the left, If an alloc is inside a dimshuffle which only adds dimension to the left,
...@@ -159,7 +155,7 @@ def local_dimshuffle_alloc(node): ...@@ -159,7 +155,7 @@ def local_dimshuffle_alloc(node):
""" """
if isinstance(node.op, DimShuffle) and node.inputs[0].owner: if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
input_ = node.inputs[0] input_ = node.inputs[0]
if isinstance(input_.owner.op, T.Alloc): if isinstance(input_.owner.op, tt.Alloc):
# check if it only adds dimension to the left # check if it only adds dimension to the left
new_order = node.op.new_order new_order = node.op.new_order
expected_new_order = ("x",) * (len(new_order) - input_.ndim) + tuple( expected_new_order = ("x",) * (len(new_order) - input_.ndim) + tuple(
...@@ -172,12 +168,12 @@ def local_dimshuffle_alloc(node): ...@@ -172,12 +168,12 @@ def local_dimshuffle_alloc(node):
nb_new_dims = len(new_order) - input_.ndim nb_new_dims = len(new_order) - input_.ndim
new_shape_input = (1,) * nb_new_dims + tuple(input_.owner.inputs[1:]) new_shape_input = (1,) * nb_new_dims + tuple(input_.owner.inputs[1:])
return [T.alloc(input_.owner.inputs[0], *new_shape_input)] return [tt.alloc(input_.owner.inputs[0], *new_shape_input)]
return False return False
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([DimShuffle]) @local_optimizer([DimShuffle])
def local_dimshuffle_subtensor(node): def local_dimshuffle_subtensor(node):
"""If a subtensor is inside a dimshuffle which only drop """If a subtensor is inside a dimshuffle which only drop
broadcastable dimensions, scrap the dimshuffle and index the broadcastable dimensions, scrap the dimshuffle and index the
...@@ -223,7 +219,7 @@ def local_dimshuffle_subtensor(node): ...@@ -223,7 +219,7 @@ def local_dimshuffle_subtensor(node):
# tensor was indexed such as x[scalar, :, :], check that as well # tensor was indexed such as x[scalar, :, :], check that as well
new_idx_list = list(input_.owner.op.idx_list) new_idx_list = list(input_.owner.op.idx_list)
new_inputs = [input_.owner.inputs[0]] new_inputs = [input_.owner.inputs[0]]
zero = T.constant(0) zero = tt.constant(0)
slice_attr_list = ["start", "stop", "step"] slice_attr_list = ["start", "stop", "step"]
j = 0 j = 0
slice_i = -1 slice_i = -1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论