Commit fe0365ad authored by Brandon T. Willard

Implement new JAX conversions for theano.tensor.extra_ops

Parent e464ba49
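For orientation, the tests below build a Theano `FunctionGraph` and compare the JAX-compiled result against the Python backend. A minimal sketch of the same flow outside the test suite — the import path and the exact call semantics of `jax_funcify` on a `FunctionGraph` are assumptions, not shown in this diff:

import numpy as np
import theano
import theano.tensor as tt

# Hypothetical import path; the module defining `jax_funcify` is not
# named in this diff.
from theano.sandbox.jaxify import jax_funcify

a = tt.matrix("a")
out = tt.extra_ops.cumsum(a, axis=0)
fgraph = theano.gof.FunctionGraph([a], [out])

# `jax_funcify` is registered on `FunctionGraph` below, so this should
# return a JAX-callable over the graph's inputs.
jax_fn = jax_funcify(fgraph)

a_val = np.arange(6, dtype=theano.config.floatX).reshape((3, 2))
print(jax_fn(a_val))  # expected to match np.cumsum(a_val, axis=0)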
@@ -664,3 +664,59 @@ def test_shared():
    jax_res = theano_jax_fn()
    assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
    np.testing.assert_allclose(jax_res, new_a_value * 2)

def test_extra_ops():
    a = tt.matrix("a")
    a.tag.test_value = np.arange(6, dtype=theano.config.floatX).reshape((3, 2))

    out = tt.extra_ops.cumsum(a, axis=0)
    fgraph = theano.gof.FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = tt.extra_ops.cumprod(a, axis=1)
    fgraph = theano.gof.FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = tt.extra_ops.diff(a, n=2, axis=1)
    fgraph = theano.gof.FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = tt.extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = theano.gof.FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # This function also cannot take symbolic input.
    c = tt.as_tensor(5)
    out = tt.extra_ops.bartlett(c)
    fgraph = theano.gof.FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = tt.extra_ops.fill_diagonal(a, c)
        fgraph = theano.gof.FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = tt.extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = theano.gof.FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = tt.extra_ops.Unique(axis=1)(a)
        fgraph = theano.gof.FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    indices = np.arange(np.product((3, 4)))
    out = tt.extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = theano.gof.FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )

    multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4))
    out = tt.extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = theano.gof.FunctionGraph([], [out])
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )

@@ -35,6 +35,7 @@ from theano.tensor.basic import (
    Alloc,
    Reshape,
    Join,
    MaxAndArgmax,
)
from theano.scalar.basic import ScalarOp, Composite, Cast, Clip, Identity
from theano.tensor.elemwise import Elemwise, CAReduce, DimShuffle

@@ -67,6 +68,21 @@ from theano.tensor.slinalg import (
    Solve,
)
from theano.tensor.type_other import MakeSlice
from theano.tensor.extra_ops import (
    CumOp,
    DiffOp,
    RepeatOp,
    Bartlett,
    FillDiagonal,
    FillDiagonalOffset,
    Unique,
    UnravelIndex,
    RavelMultiIndex,
)

if theano.config.floatX == "float64":
    jax.config.update("jax_enable_x64", True)
else:

@@ -82,7 +98,7 @@ except AttributeError:
    pass

subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)
-incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor)
+incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)

def compose_jax_funcs(out_node, fgraph_inputs, memo=None):

@@ -116,15 +132,23 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
        if i in fgraph_inputs:
            idx = fgraph_inputs.index(i)

-            def jax_inputs_func(*inputs, i_dtype=i.dtype, idx=idx):
+            i_dtype = getattr(i, "dtype", None)

            def jax_inputs_func(*inputs, i_dtype=i_dtype, idx=idx):
                return jnp.array(inputs[idx], dtype=jnp.dtype(i_dtype))

            input_f = jax_inputs_func

        elif i.owner is None:

-            def jax_data_func(*inputs, i_dtype=i.dtype, i_data=i.data):
-                return jnp.array(i_data, dtype=jnp.dtype(i_dtype))
+            i_dtype = getattr(i, "dtype", None)
+            i_data = i.data

            def jax_data_func(*inputs, i_dtype=i_dtype, i_data=i_data):
                if i_dtype is None:
                    return i_data
                else:
                    return jnp.array(i_data, dtype=jnp.dtype(i_dtype))

            input_f = jax_data_func
        else:

@@ -158,6 +182,14 @@ def jax_funcify(op):
    raise NotImplementedError("No JAX conversion for the given `Op`: {}".format(op))

@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op):
    def makeslice(*x):
        return slice(*x)

    return makeslice

@jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op):
    func_name = op.nfunc_spec[0]

@@ -288,8 +320,13 @@ def jax_funcify_Shape_i(op):
@jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op):
    def specifyshape(x, shape):
-        assert x.ndim == shape.size
-        assert jnp.all(x.shape == shape), ("got shape", x.shape, "expected", shape)
+        assert x.ndim == len(shape)
+        assert jnp.all(x.shape == tuple(shape)), (
            "got shape",
            x.shape,
            "expected",
            shape,
        )
        return x

    return specifyshape

@@ -475,11 +512,15 @@ def jax_funcify_Scan(op):
@jax_funcify.register(IfElse)
def jax_funcify_IfElse(op):
-    def ifelse(cond, *args):
+    n_outs = op.n_outs

    def ifelse(cond, *args, n_outs=n_outs):
        if cond:
-            return args[: op.n_outs]
+            res = args[:n_outs]
        else:
-            return args[op.n_outs :]
+            res = args[n_outs:]
        return res if n_outs > 1 else res[0]

    return ifelse
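Note the pattern here, repeated throughout the diff: `Op` attributes (`n_outs`, `axis`, `idx_list`, ...) are read once and bound as keyword-argument defaults, so the returned closure captures plain Python values rather than the Theano `op` object. A minimal illustration of the binding, with a hypothetical stand-in class that is not from the commit:

class FakeOp:
    n_outs = 2

op = FakeOp()
n_outs = op.n_outs

def ifelse_like(cond, *args, n_outs=n_outs):
    # `n_outs` was frozen at function-definition time.
    return args[:n_outs] if cond else args[n_outs:]

del op  # the closure still works without the op object
assert ifelse_like(True, 1, 2, 3, 4) == (1, 2)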

@@ -526,14 +567,16 @@ _ = [jax_funcify.register(op, jax_funcify_Subtensor) for op in subtensor_ops]
def jax_funcify_IncSubtensor(op):

    idx_list = op.idx_list

    if getattr(op, "set_instead_of_inc", False):
        jax_fn = jax.ops.index_update
    else:
        jax_fn = jax.ops.index_add

-    def incsubtensor(x, y, *ilist, jax_fn=jax_fn):
+    def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
        _ilist = list(ilist)
-        cdata = tuple(convert_indices(_ilist, idx) for idx in op.idx_list)
+        cdata = tuple(convert_indices(_ilist, idx) for idx in idx_list)
        if len(cdata) == 1:
            cdata = cdata[0]

@@ -545,6 +588,20 @@ def jax_funcify_IncSubtensor(op):
_ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_ops]
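Both this converter and the one below lean on the functional `jax.ops` indexing API of the era (since removed from JAX in favor of `x.at[idx].set(...)` / `.add(...)`). A quick sketch of its behavior, for reference:

import jax
import jax.numpy as jnp

x = jnp.zeros((3, 3))
y = jax.ops.index_update(x, (0, 1), 5.0)  # like x.at[0, 1].set(5.0)
z = jax.ops.index_add(x, (2,), 1.0)       # like x.at[2].add(1.0)
# `x` itself is unchanged; JAX arrays are immutable.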

@jax_funcify.register(BaseAdvancedIncSubtensor)
def jax_funcify_BaseAdvancedIncSubtensor(op):

    if getattr(op, "set_instead_of_inc", False):
        jax_fn = jax.ops.index_update
    else:
        jax_fn = jax.ops.index_add

    def baseadvancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
        return jax_fn(x, ilist, y)

    return baseadvancedincsubtensor

@jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph(fgraph):

@@ -656,6 +713,44 @@ def jax_funcify_Join(op):
    return join

@jax_funcify.register(MaxAndArgmax)
def jax_funcify_MaxAndArgmax(op):
    axis = op.axis

    def maxandargmax(x, axis=axis):
        if axis is None:
            axes = tuple(range(x.ndim))
        else:
            axes = tuple(int(ax) for ax in axis)

        max_res = jnp.max(x, axis)

        # NumPy does not support multiple axes for argmax; this is a
        # work-around
        keep_axes = jnp.array(
            [i for i in range(x.ndim) if i not in axes], dtype="int64"
        )
        # Not-reduced axes in front
        transposed_x = jnp.transpose(
            x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
        )
        kept_shape = transposed_x.shape[: len(keep_axes)]
        reduced_shape = transposed_x.shape[len(keep_axes) :]

        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
        # Otherwise reshape would complain citing float arg
        new_shape = kept_shape + (
            jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
        )
        reshaped_x = transposed_x.reshape(new_shape)

        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")

        return max_res, max_idx_res

    return maxandargmax
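The argmax work-around above is easiest to see in plain NumPy: move the kept axes to the front, flatten the reduced axes into one, and take a single-axis argmax. A standalone sketch with illustrative shapes, not from the commit:

import numpy as np

x = np.random.rand(2, 3, 4)
axes = (1, 2)                                      # axes to reduce over
keep = tuple(i for i in range(x.ndim) if i not in axes)

xt = x.transpose(keep + axes)                      # kept axes in front
flat = xt.reshape(xt.shape[: len(keep)] + (-1,))   # fuse reduced axes

idx = np.argmax(flat, axis=-1)                     # flat index per kept position
np.testing.assert_allclose(flat.max(axis=-1), x.max(axis=axes))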

@jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op):
    offset = op.offset

@@ -763,3 +858,141 @@ def jax_funcify_SVD(op):
        return jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)

    return svd

@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op):
    axis = op.axis
    mode = op.mode

    def cumop(x, axis=axis, mode=mode):
        if mode == "add":
            return jnp.cumsum(x, axis=axis)
        else:
            return jnp.cumprod(x, axis=axis)

    return cumop

@jax_funcify.register(DiffOp)
def jax_funcify_DiffOp(op):
    n = op.n
    axis = op.axis

    def diffop(x, n=n, axis=axis):
        return jnp.diff(x, n=n, axis=axis)

    return diffop

@jax_funcify.register(RepeatOp)
def jax_funcify_RepeatOp(op):
    axis = op.axis

    def repeatop(x, repeats, axis=axis):
        return jnp.repeat(x, repeats, axis=axis)

    return repeatop

@jax_funcify.register(Bartlett)
def jax_funcify_Bartlett(op):
    def bartlett(x):
        return jnp.bartlett(x)

    return bartlett

@jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op):

    # def filldiagonal(a, val):
    #     if a.ndim == 2:
    #         step = a.shape[1] + 1
    #         end = a.shape[1] * a.shape[1]
    #         a.flat[:end:step] = val
    #     else:
    #         jnp.fill_diagonal(a, val)
    #
    #     return a
    #
    # return filldiagonal

    raise NotImplementedError("flatiter not implemented in JAX")
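The port is commented out because JAX arrays are immutable and expose no `flat` assignment. For the 2-D case a functional version is still possible with the same old `jax.ops` API; a sketch of an alternative, not what this commit ships:

import jax
import jax.numpy as jnp

def fill_diagonal_2d(a, val):
    # Set a[i, i] = val without mutation by scattering along the diagonal.
    n = min(a.shape[0], a.shape[1])
    idx = jnp.arange(n)
    return jax.ops.index_update(a, (idx, idx), val)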

@jax_funcify.register(FillDiagonalOffset)
def jax_funcify_FillDiagonalOffset(op):

    # def filldiagonaloffset(a, val, offset):
    #     height, width = a.shape
    #
    #     if offset >= 0:
    #         start = offset
    #         num_of_step = min(min(width, height), width - offset)
    #     else:
    #         start = -offset * a.shape[1]
    #         num_of_step = min(min(width, height), height + offset)
    #
    #     step = a.shape[1] + 1
    #     end = start + step * num_of_step
    #     a.flat[start:end:step] = val
    #
    #     return a
    #
    # return filldiagonaloffset

    raise NotImplementedError("flatiter not implemented in JAX")

@jax_funcify.register(Unique)
def jax_funcify_Unique(op):
    return_index = op.return_index
    return_inverse = op.return_inverse
    return_counts = op.return_counts
    axis = op.axis

    def unique(
        x,
        return_index=return_index,
        return_inverse=return_inverse,
        return_counts=return_counts,
        axis=axis,
    ):
        param = {}
        if return_index:
            param["return_index"] = True
        if return_inverse:
            param["return_inverse"] = True
        if return_counts:
            param["return_counts"] = True
        if axis is not None:
            param["axis"] = axis
        return jnp.unique(x, **param)

    return unique

@jax_funcify.register(UnravelIndex)
def jax_funcify_UnravelIndex(op):
    order = op.order
    warn("JAX ignores the `order` parameter in `unravel_index`.")

    def unravelindex(indices, dims, order=order):
        return jnp.unravel_index(indices, dims)

    return unravelindex

@jax_funcify.register(RavelMultiIndex)
def jax_funcify_RavelMultiIndex(op):
    mode = op.mode
    order = op.order

    def ravelmultiindex(*inp, mode=mode, order=order):
        multi_index, dims = inp[:-1], inp[-1]
        return jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order)

    return ravelmultiindex
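As a sanity check on the last two conversions, `unravel_index` and `ravel_multi_index` round-trip each other in `jax.numpy` — a sketch mirroring the test above:

import jax.numpy as jnp

indices = jnp.arange(12)
multi_index = jnp.unravel_index(indices, (3, 4))
back = jnp.ravel_multi_index(multi_index, (3, 4), mode="raise", order="C")
assert (back == indices).all()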