Unverified 提交 8924d294 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #80 from brandonwillard/add-cumop-jaxification

Add JAX conversions for theano.tensor.extra_ops
......@@ -19,17 +19,31 @@ def set_theano_flags():
def compare_jax_and_py(
fgraph, inputs, assert_fn=partial(np.testing.assert_allclose, rtol=1e-4)
fgraph,
inputs,
assert_fn=partial(np.testing.assert_allclose, rtol=1e-4),
simplify=False,
must_be_device_array=True,
):
theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode="JAX")
if not simplify:
opts = theano.gof.Query(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = theano.compile.mode.Mode(theano.sandbox.jax_linker.JAXLinker(), opts)
py_mode = theano.compile.Mode("py", opts)
else:
py_mode = theano.compile.Mode(linker="py")
jax_mode = "JAX"
theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
jax_res = theano_jax_fn(*inputs)
if isinstance(jax_res, list):
assert all(isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res)
else:
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
if must_be_device_array:
if isinstance(jax_res, list):
assert all(
isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res
)
else:
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
py_mode = theano.compile.Mode(linker="py")
theano_py_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=py_mode)
py_res = theano_py_fn(*inputs)
......@@ -88,12 +102,12 @@ def test_jax_compile_ops():
x = theano.compile.ops.Shape()(tt.as_tensor_variable(x_np))
x_fg = theano.gof.FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
compare_jax_and_py(x_fg, [], must_be_device_array=False)
x = theano.compile.ops.Shape_i(1)(tt.as_tensor_variable(x_np))
x_fg = theano.gof.FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
compare_jax_and_py(x_fg, [], must_be_device_array=False)
x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (20, 3))
x_fg = theano.gof.FunctionGraph([], [x])
......@@ -340,7 +354,7 @@ def test_jax_Subtensors():
def test_jax_IncSubtensor():
x_np = np.empty((3, 4, 5), dtype=tt.config.floatX)
x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(tt.config.floatX)
x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(tt.config.floatX)
# "Set" basic indices
......@@ -410,6 +424,8 @@ def test_jax_IncSubtensor():
def test_jax_ifelse():
import theano.ifelse
true_vals = np.r_[1, 2, 3]
false_vals = np.r_[-1, -2, -3]
......@@ -648,3 +664,59 @@ def test_shared():
jax_res = theano_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
np.testing.assert_allclose(jax_res, new_a_value * 2)
def test_extra_ops():
    """Test JAX conversions for the `Op`s in `theano.tensor.extra_ops`."""
    a = tt.matrix("a")
    a.tag.test_value = np.arange(6, dtype=theano.config.floatX).reshape((3, 2))

    out = tt.extra_ops.cumsum(a, axis=0)
    fgraph = theano.gof.FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = tt.extra_ops.cumprod(a, axis=1)
    fgraph = theano.gof.FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = tt.extra_ops.diff(a, n=2, axis=1)
    fgraph = theano.gof.FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = tt.extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = theano.gof.FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # This function also cannot take symbolic input.
    c = tt.as_tensor(5)
    out = tt.extra_ops.bartlett(c)
    fgraph = theano.gof.FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # These `Op`s have no JAX conversion (see the corresponding
    # `NotImplementedError`s raised at funcify time).
    with pytest.raises(NotImplementedError):
        out = tt.extra_ops.fill_diagonal(a, c)
        fgraph = theano.gof.FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = tt.extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = theano.gof.FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = tt.extra_ops.Unique(axis=1)(a)
        fgraph = theano.gof.FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # NOTE: `np.product` is deprecated (removed in NumPy 2.0); use `np.prod`.
    indices = np.arange(np.prod((3, 4)))
    out = tt.extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = theano.gof.FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )

    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
    out = tt.extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = theano.gof.FunctionGraph([], [out])
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )
......@@ -1065,13 +1065,10 @@ class TestUnravelIndex(utt.InferShapeTester):
indices_symb = theano.shared(indices)
# reference result
ref = np.unravel_index(indices, shape)
ref = np.unravel_index(indices, shape, order=order)
def fn(i, d, nd=None):
if nd is None:
return function([], unravel_index(i, d, order=order))
else:
return function([], unravel_index(i, d, order=order, ndim=nd))
def fn(i, d):
return function([], unravel_index(i, d, order=order))
# shape given as a tuple
f_array_tuple = fn(indices, shape)
......@@ -1086,7 +1083,7 @@ class TestUnravelIndex(utt.InferShapeTester):
# shape given as a theano variable
shape_symb = theano.shared(shape_array)
f_array_symb = fn(indices, shape_symb, len(shape))
f_array_symb = fn(indices, shape_symb)
np.testing.assert_equal(ref, f_array_symb())
# shape given as a Shape op (unravel_index will use get_vector_length
......@@ -1098,7 +1095,7 @@ class TestUnravelIndex(utt.InferShapeTester):
# shape testing
self._compile_and_check(
[],
unravel_index(indices, shape_symb, order=order, ndim=len(shape)),
unravel_index(indices, shape_symb, order=order),
[],
UnravelIndex,
)
......@@ -1118,8 +1115,6 @@ class TestUnravelIndex(utt.InferShapeTester):
unravel_index(theano.tensor.fvector(), (3, 4))
with pytest.raises(TypeError):
unravel_index((3, 4), (3.4, 3.2))
with pytest.raises(ValueError):
unravel_index((3, 4), (3, 3), ndim=5.4)
# dims must be a 1D sequence
with pytest.raises(TypeError):
......
......@@ -35,6 +35,7 @@ from theano.tensor.basic import (
Alloc,
Reshape,
Join,
MaxAndArgmax,
)
from theano.scalar.basic import ScalarOp, Composite, Cast, Clip, Identity
from theano.tensor.elemwise import Elemwise, CAReduce, DimShuffle
......@@ -67,6 +68,21 @@ from theano.tensor.slinalg import (
Solve,
)
from theano.tensor.type_other import MakeSlice
from theano.tensor.extra_ops import (
CumOp,
DiffOp,
RepeatOp,
Bartlett,
FillDiagonal,
FillDiagonalOffset,
Unique,
UnravelIndex,
RavelMultiIndex,
)
if theano.config.floatX == "float64":
jax.config.update("jax_enable_x64", True)
else:
......@@ -82,7 +98,7 @@ except AttributeError:
pass
subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
......@@ -116,15 +132,23 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
if i in fgraph_inputs:
idx = fgraph_inputs.index(i)
def jax_inputs_func(*inputs, i_dtype=i.dtype, idx=idx):
i_dtype = getattr(i, "dtype", None)
def jax_inputs_func(*inputs, i_dtype=i_dtype, idx=idx):
return jnp.array(inputs[idx], dtype=jnp.dtype(i_dtype))
input_f = jax_inputs_func
elif i.owner is None:
def jax_data_func(*inputs, i_dtype=i.dtype, i_data=i.data):
return jnp.array(i_data, dtype=jnp.dtype(i_dtype))
i_dtype = getattr(i, "dtype", None)
i_data = i.data
def jax_data_func(*inputs, i_dtype=i_dtype, i_data=i_data):
if i_dtype is None:
return i_data
else:
return jnp.array(i_data, dtype=jnp.dtype(i_dtype))
input_f = jax_data_func
else:
......@@ -158,6 +182,14 @@ def jax_funcify(op):
raise NotImplementedError("No JAX conversion for the given `Op`: {}".format(op))
@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op):
    """Create a JAX-compatible function for `MakeSlice` `Op`s."""

    def makeslice(*args):
        # `MakeSlice` simply packs its (start, stop, step) inputs into a
        # Python `slice` object.
        return slice(*args)

    return makeslice
@jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op):
func_name = op.nfunc_spec[0]
......@@ -288,8 +320,13 @@ def jax_funcify_Shape_i(op):
@jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op):
def specifyshape(x, shape):
assert x.ndim == shape.size
assert jnp.all(x.shape == shape), ("got shape", x.shape, "expected", shape)
assert x.ndim == len(shape)
assert jnp.all(x.shape == tuple(shape)), (
"got shape",
x.shape,
"expected",
shape,
)
return x
return specifyshape
......@@ -475,11 +512,15 @@ def jax_funcify_Scan(op):
@jax_funcify.register(IfElse)
def jax_funcify_IfElse(op):
def ifelse(cond, *args):
n_outs = op.n_outs
def ifelse(cond, *args, n_outs=n_outs):
if cond:
return args[: op.n_outs]
res = args[:n_outs]
else:
return args[op.n_outs :]
res = args[n_outs:]
return res if n_outs > 1 else res[0]
return ifelse
......@@ -526,14 +567,16 @@ _ = [jax_funcify.register(op, jax_funcify_Subtensor) for op in subtensor_ops]
def jax_funcify_IncSubtensor(op):
idx_list = op.idx_list
if getattr(op, "set_instead_of_inc", False):
jax_fn = jax.ops.index_update
else:
jax_fn = jax.ops.index_add
def incsubtensor(x, y, *ilist, jax_fn=jax_fn):
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
_ilist = list(ilist)
cdata = tuple(convert_indices(_ilist, idx) for idx in op.idx_list)
cdata = tuple(convert_indices(_ilist, idx) for idx in idx_list)
if len(cdata) == 1:
cdata = cdata[0]
......@@ -545,6 +588,20 @@ def jax_funcify_IncSubtensor(op):
_ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_ops]
@jax_funcify.register(BaseAdvancedIncSubtensor)
def jax_funcify_BaseAdvancedIncSubtensor(op):
    """Create a JAX-compatible function for advanced inc/set-subtensor `Op`s."""
    # "Set" semantics map to `index_update`; "increment" semantics map to
    # `index_add`.
    set_instead_of_inc = getattr(op, "set_instead_of_inc", False)
    jax_fn = jax.ops.index_update if set_instead_of_inc else jax.ops.index_add

    def baseadvancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
        return jax_fn(x, ilist, y)

    return baseadvancedincsubtensor
@jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph(fgraph):
......@@ -656,6 +713,44 @@ def jax_funcify_Join(op):
return join
@jax_funcify.register(MaxAndArgmax)
def jax_funcify_MaxAndArgmax(op):
    """Create a JAX-compatible function for `MaxAndArgmax` `Op`s.

    Returns a function computing ``(max, argmax)`` of its input over
    ``op.axis`` (all axes when ``op.axis`` is ``None``).
    """
    axis = op.axis

    def maxandargmax(x, axis=axis):
        # Normalize `axis` to a tuple of ints covering all reduced axes.
        if axis is None:
            axes = tuple(range(x.ndim))
        else:
            axes = tuple(int(ax) for ax in axis)

        max_res = jnp.max(x, axis)

        # NumPy does not support multiple axes for argmax; this is a
        # work-around
        keep_axes = jnp.array(
            [i for i in range(x.ndim) if i not in axes], dtype="int64"
        )
        # Not-reduced axes in front
        transposed_x = jnp.transpose(
            x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
        )
        kept_shape = transposed_x.shape[: len(keep_axes)]
        reduced_shape = transposed_x.shape[len(keep_axes) :]

        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
        # Otherwise reshape would complain citing float arg
        new_shape = kept_shape + (
            jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
        )
        # Collapse all reduced axes into one trailing axis so a single
        # `argmax(axis=-1)` yields the flat index within the reduced axes.
        reshaped_x = transposed_x.reshape(new_shape)

        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")

        return max_res, max_idx_res

    return maxandargmax
@jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op):
offset = op.offset
......@@ -763,3 +858,141 @@ def jax_funcify_SVD(op):
return jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
return svd
@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op):
    """Create a JAX-compatible function for `CumOp` (cumulative sum/product) `Op`s."""
    op_axis = op.axis
    op_mode = op.mode

    def cumop(x, axis=op_axis, mode=op_mode):
        # `mode == "add"` means cumulative sum; anything else is a
        # cumulative product.
        jax_fn = jnp.cumsum if mode == "add" else jnp.cumprod
        return jax_fn(x, axis=axis)

    return cumop
@jax_funcify.register(DiffOp)
def jax_funcify_DiffOp(op):
    """Create a JAX-compatible function for `DiffOp` `Op`s."""
    # Capture the `Op`'s parameters at funcify time.
    op_n, op_axis = op.n, op.axis

    def diffop(x, n=op_n, axis=op_axis):
        # n-th order discrete difference along the given axis.
        return jnp.diff(x, n=n, axis=axis)

    return diffop
@jax_funcify.register(RepeatOp)
def jax_funcify_RepeatOp(op):
    """Create a JAX-compatible function for `RepeatOp` `Op`s."""
    op_axis = op.axis

    def repeatop(x, repeats, axis=op_axis):
        # `repeats` is a runtime input; only `axis` comes from the `Op`.
        return jnp.repeat(x, repeats, axis=axis)

    return repeatop
@jax_funcify.register(Bartlett)
def jax_funcify_Bartlett(op):
    """Create a JAX-compatible function for `Bartlett` window `Op`s."""

    def bartlett(x):
        # Return the `x`-point Bartlett (triangular) window.
        return jnp.bartlett(x)

    return bartlett
@jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op):
    """`FillDiagonal` has no JAX conversion.

    A NumPy implementation would assign through ``a.flat`` (or call
    ``np.fill_diagonal``), but JAX arrays are immutable and provide no
    `flatiter` interface, so the conversion fails at funcify time.
    """
    raise NotImplementedError("flatiter not implemented in JAX")
@jax_funcify.register(FillDiagonalOffset)
def jax_funcify_FillDiagonalOffset(op):
    """`FillDiagonalOffset` has no JAX conversion.

    The NumPy implementation writes the offset diagonal via strided
    assignment through ``a.flat``; JAX arrays are immutable and expose no
    `flatiter`, so the conversion fails at funcify time.
    """
    raise NotImplementedError("flatiter not implemented in JAX")
@jax_funcify.register(Unique)
def jax_funcify_Unique(op):
    """Create a JAX-compatible function for `Unique` `Op`s."""
    # Snapshot the `Op`'s output options at funcify time.
    flags = (
        ("return_index", op.return_index),
        ("return_inverse", op.return_inverse),
        ("return_counts", op.return_counts),
    )
    op_axis = op.axis

    def unique(x, flags=flags, axis=op_axis):
        # Forward only the options that were requested so the number of
        # outputs matches what the Theano `Op` declares.
        kwargs = {name: True for name, enabled in flags if enabled}
        if axis is not None:
            kwargs["axis"] = axis
        return jnp.unique(x, **kwargs)

    return unique
@jax_funcify.register(UnravelIndex)
def jax_funcify_UnravelIndex(op):
    """Create a JAX-compatible function for `UnravelIndex` `Op`s.

    JAX's `unravel_index` only supports row-major (C) order, so an `Op`
    with ``order="F"`` cannot be converted faithfully.

    Raises
    ------
    NotImplementedError
        If ``op.order`` is not ``"C"``; silently returning C-order results
        (as a mere warning would allow) would be incorrect.
    """
    if op.order != "C":
        raise NotImplementedError(
            "JAX's `unravel_index` does not support the `order` parameter."
        )

    def unravelindex(indices, dims):
        return jnp.unravel_index(indices, dims)

    return unravelindex
@jax_funcify.register(RavelMultiIndex)
def jax_funcify_RavelMultiIndex(op):
    """Create a JAX-compatible function for `RavelMultiIndex` `Op`s."""
    op_mode = op.mode
    op_order = op.order

    def ravelmultiindex(*inp, mode=op_mode, order=op_order):
        # The last input is the target shape; all preceding inputs are the
        # per-dimension index arrays.
        multi_index, dims = inp[:-1], inp[-1]
        return jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order)

    return ravelmultiindex
......@@ -5002,6 +5002,8 @@ def get_vector_length(v):
raise TypeError("argument must be symbolic vector, got '%s'" % v)
if v.type.broadcastable[0]:
return 1
if isinstance(v, theano.tensor.sharedvar.TensorSharedVariable) and v.type.ndim == 1:
return len(v.get_value())
if isinstance(v, gof.Constant) and v.type.ndim == 1:
return len(v.data)
if v.owner and isinstance(v.owner.op, theano.tensor.opt.MakeVector):
......
......@@ -294,7 +294,10 @@ class CumOp(theano.Op):
def perform(self, node, inputs, output_storage, params):
x = inputs[0]
z = output_storage[0]
z[0] = {"add": np.cumsum, "mul": np.cumprod}[self.mode](x, axis=self.axis)
if self.mode == "add":
z[0] = np.cumsum(x, axis=self.axis)
else:
z[0] = np.cumprod(x, axis=self.axis)
def grad(self, inputs, output_gradients):
(x,) = inputs
......@@ -1289,13 +1292,10 @@ class Unique(theano.Op):
class UnravelIndex(gof.Op):
__props__ = ("ndim", "order")
__props__ = ("order",)
def __init__(self, ndim, order="C"):
def __init__(self, order="C"):
assert order in ("C", "F")
if not isinstance(ndim, int) or ndim < 1:
raise ValueError("ndim must be an integer greater than 0")
self.ndim = int(ndim)
self.order = order
def make_node(self, indices, dims):
......@@ -1318,7 +1318,7 @@ class UnravelIndex(gof.Op):
[indices, dims],
[
basic.TensorType(dtype="int64", broadcastable=(False,) * indices.ndim)()
for i in range(self.ndim)
for i in range(basic.get_vector_length(dims))
],
)
......@@ -1327,7 +1327,7 @@ class UnravelIndex(gof.Op):
def perform(self, node, inp, out):
indices, dims = inp
res = np.unravel_index(indices, dims)
res = np.unravel_index(indices, dims, order=self.order)
assert len(res) == len(out)
for i in range(len(out)):
ret = theano._asarray(res[i], node.outputs[0].dtype)
......@@ -1338,15 +1338,11 @@ class UnravelIndex(gof.Op):
out[i][0] = ret
def unravel_index(indices, dims, order="C", ndim=None):
def unravel_index(indices, dims, order="C"):
"""
Converts a flat index or array of flat indices into a tuple
of coordinate arrays.
This method is similar to the NumPy version, except for the
additional ``ndim`` parameter. This parameter is required if
the length of ``dims`` cannot be determined automatically.
Parameters
----------
indices : Theano or NumPy array
......@@ -1357,10 +1353,6 @@ def unravel_index(indices, dims, order="C", ndim=None):
order : {'C', 'F'}, optional
Determines whether the indices should be viewed as indexing in
row-major (C-style) or column-major (Fortran-style) order.
ndim : int, optional
Specifies the number of dimensions, i.e., the length of
``dims``. This is required if the dimensions cannot be determined
automatically from ``dims`` itself.
Returns
-------
......@@ -1373,20 +1365,8 @@ def unravel_index(indices, dims, order="C", ndim=None):
ravel_multi_index
"""
if ndim is None:
try:
ndim = basic.get_vector_length(dims)
except ValueError:
raise ValueError(
"The length of the provided dimension list (%s) cannot "
"be automatically determined, so Theano is not able "
"to know what the number of dimensions of the unraveled "
"index will be. You can provide the 'ndim' keyword "
"argument to 'unravel_index' to avoid this problem." % str(dims)
)
res = UnravelIndex(ndim=ndim, order=order)(indices, dims)
if ndim == 1:
res = UnravelIndex(order=order)(indices, dims)
if not isinstance(res, (list, tuple)):
return (res,)
else:
return tuple(res)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论