Unverified commit 8924d294, authored by Brandon T. Willard, committed by GitHub

Merge pull request #80 from brandonwillard/add-cumop-jaxification

Add JAX conversions for theano.tensor.extra_ops
...@@ -19,17 +19,31 @@ def set_theano_flags(): ...@@ -19,17 +19,31 @@ def set_theano_flags():
def compare_jax_and_py( def compare_jax_and_py(
fgraph, inputs, assert_fn=partial(np.testing.assert_allclose, rtol=1e-4) fgraph,
inputs,
assert_fn=partial(np.testing.assert_allclose, rtol=1e-4),
simplify=False,
must_be_device_array=True,
): ):
theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode="JAX") if not simplify:
opts = theano.gof.Query(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = theano.compile.mode.Mode(theano.sandbox.jax_linker.JAXLinker(), opts)
py_mode = theano.compile.Mode("py", opts)
else:
py_mode = theano.compile.Mode(linker="py")
jax_mode = "JAX"
theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
jax_res = theano_jax_fn(*inputs) jax_res = theano_jax_fn(*inputs)
if isinstance(jax_res, list): if must_be_device_array:
assert all(isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res) if isinstance(jax_res, list):
else: assert all(
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res
)
else:
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
py_mode = theano.compile.Mode(linker="py")
theano_py_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=py_mode) theano_py_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=py_mode)
py_res = theano_py_fn(*inputs) py_res = theano_py_fn(*inputs)
...@@ -88,12 +102,12 @@ def test_jax_compile_ops(): ...@@ -88,12 +102,12 @@ def test_jax_compile_ops():
x = theano.compile.ops.Shape()(tt.as_tensor_variable(x_np)) x = theano.compile.ops.Shape()(tt.as_tensor_variable(x_np))
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [], must_be_device_array=False)
x = theano.compile.ops.Shape_i(1)(tt.as_tensor_variable(x_np)) x = theano.compile.ops.Shape_i(1)(tt.as_tensor_variable(x_np))
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [], must_be_device_array=False)
x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (20, 3)) x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (20, 3))
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
...@@ -340,7 +354,7 @@ def test_jax_Subtensors(): ...@@ -340,7 +354,7 @@ def test_jax_Subtensors():
def test_jax_IncSubtensor(): def test_jax_IncSubtensor():
x_np = np.empty((3, 4, 5), dtype=tt.config.floatX) x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(tt.config.floatX)
x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(tt.config.floatX) x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(tt.config.floatX)
# "Set" basic indices # "Set" basic indices
...@@ -410,6 +424,8 @@ def test_jax_IncSubtensor(): ...@@ -410,6 +424,8 @@ def test_jax_IncSubtensor():
def test_jax_ifelse(): def test_jax_ifelse():
import theano.ifelse
true_vals = np.r_[1, 2, 3] true_vals = np.r_[1, 2, 3]
false_vals = np.r_[-1, -2, -3] false_vals = np.r_[-1, -2, -3]
...@@ -648,3 +664,59 @@ def test_shared(): ...@@ -648,3 +664,59 @@ def test_shared():
jax_res = theano_jax_fn() jax_res = theano_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
np.testing.assert_allclose(jax_res, new_a_value * 2) np.testing.assert_allclose(jax_res, new_a_value * 2)
def test_extra_ops():
    """Check the JAX conversions of the `theano.tensor.extra_ops` `Op`s.

    Each case builds a small `FunctionGraph` and compares the JAX-compiled
    result against the pure-Python mode via `compare_jax_and_py`.
    """
    a = tt.matrix("a")
    a.tag.test_value = np.arange(6, dtype=theano.config.floatX).reshape((3, 2))

    out = tt.extra_ops.cumsum(a, axis=0)
    fgraph = theano.gof.FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = tt.extra_ops.cumprod(a, axis=1)
    fgraph = theano.gof.FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = tt.extra_ops.diff(a, n=2, axis=1)
    fgraph = theano.gof.FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = tt.extra_ops.repeat(a, (3, 3), axis=1)
    fgraph = theano.gof.FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # This function also cannot take symbolic input.
    c = tt.as_tensor(5)
    out = tt.extra_ops.bartlett(c)
    fgraph = theano.gof.FunctionGraph([], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # The following `Op`s intentionally have no JAX conversion (see the
    # `NotImplementedError`s raised by their `jax_funcify` registrations).
    with pytest.raises(NotImplementedError):
        out = tt.extra_ops.fill_diagonal(a, c)
        fgraph = theano.gof.FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = tt.extra_ops.fill_diagonal_offset(a, c, c)
        fgraph = theano.gof.FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    with pytest.raises(NotImplementedError):
        out = tt.extra_ops.Unique(axis=1)(a)
        fgraph = theano.gof.FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # NOTE: `np.product` is a deprecated alias (removed in NumPy 2.0);
    # use the canonical `np.prod` instead.
    indices = np.arange(np.prod((3, 4)))
    out = tt.extra_ops.unravel_index(indices, (3, 4), order="C")
    fgraph = theano.gof.FunctionGraph([], out)
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )

    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
    out = tt.extra_ops.ravel_multi_index(multi_index, (3, 4))
    fgraph = theano.gof.FunctionGraph([], [out])
    compare_jax_and_py(
        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
    )
...@@ -1065,13 +1065,10 @@ class TestUnravelIndex(utt.InferShapeTester): ...@@ -1065,13 +1065,10 @@ class TestUnravelIndex(utt.InferShapeTester):
indices_symb = theano.shared(indices) indices_symb = theano.shared(indices)
# reference result # reference result
ref = np.unravel_index(indices, shape) ref = np.unravel_index(indices, shape, order=order)
def fn(i, d, nd=None): def fn(i, d):
if nd is None: return function([], unravel_index(i, d, order=order))
return function([], unravel_index(i, d, order=order))
else:
return function([], unravel_index(i, d, order=order, ndim=nd))
# shape given as a tuple # shape given as a tuple
f_array_tuple = fn(indices, shape) f_array_tuple = fn(indices, shape)
...@@ -1086,7 +1083,7 @@ class TestUnravelIndex(utt.InferShapeTester): ...@@ -1086,7 +1083,7 @@ class TestUnravelIndex(utt.InferShapeTester):
# shape given as a theano variable # shape given as a theano variable
shape_symb = theano.shared(shape_array) shape_symb = theano.shared(shape_array)
f_array_symb = fn(indices, shape_symb, len(shape)) f_array_symb = fn(indices, shape_symb)
np.testing.assert_equal(ref, f_array_symb()) np.testing.assert_equal(ref, f_array_symb())
# shape given as a Shape op (unravel_index will use get_vector_length # shape given as a Shape op (unravel_index will use get_vector_length
...@@ -1098,7 +1095,7 @@ class TestUnravelIndex(utt.InferShapeTester): ...@@ -1098,7 +1095,7 @@ class TestUnravelIndex(utt.InferShapeTester):
# shape testing # shape testing
self._compile_and_check( self._compile_and_check(
[], [],
unravel_index(indices, shape_symb, order=order, ndim=len(shape)), unravel_index(indices, shape_symb, order=order),
[], [],
UnravelIndex, UnravelIndex,
) )
...@@ -1118,8 +1115,6 @@ class TestUnravelIndex(utt.InferShapeTester): ...@@ -1118,8 +1115,6 @@ class TestUnravelIndex(utt.InferShapeTester):
unravel_index(theano.tensor.fvector(), (3, 4)) unravel_index(theano.tensor.fvector(), (3, 4))
with pytest.raises(TypeError): with pytest.raises(TypeError):
unravel_index((3, 4), (3.4, 3.2)) unravel_index((3, 4), (3.4, 3.2))
with pytest.raises(ValueError):
unravel_index((3, 4), (3, 3), ndim=5.4)
# dims must be a 1D sequence # dims must be a 1D sequence
with pytest.raises(TypeError): with pytest.raises(TypeError):
......
...@@ -35,6 +35,7 @@ from theano.tensor.basic import ( ...@@ -35,6 +35,7 @@ from theano.tensor.basic import (
Alloc, Alloc,
Reshape, Reshape,
Join, Join,
MaxAndArgmax,
) )
from theano.scalar.basic import ScalarOp, Composite, Cast, Clip, Identity from theano.scalar.basic import ScalarOp, Composite, Cast, Clip, Identity
from theano.tensor.elemwise import Elemwise, CAReduce, DimShuffle from theano.tensor.elemwise import Elemwise, CAReduce, DimShuffle
...@@ -67,6 +68,21 @@ from theano.tensor.slinalg import ( ...@@ -67,6 +68,21 @@ from theano.tensor.slinalg import (
Solve, Solve,
) )
from theano.tensor.type_other import MakeSlice
from theano.tensor.extra_ops import (
CumOp,
DiffOp,
RepeatOp,
Bartlett,
FillDiagonal,
FillDiagonalOffset,
Unique,
UnravelIndex,
RavelMultiIndex,
)
if theano.config.floatX == "float64": if theano.config.floatX == "float64":
jax.config.update("jax_enable_x64", True) jax.config.update("jax_enable_x64", True)
else: else:
...@@ -82,7 +98,7 @@ except AttributeError: ...@@ -82,7 +98,7 @@ except AttributeError:
pass pass
subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor) subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor) incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
def compose_jax_funcs(out_node, fgraph_inputs, memo=None): def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
...@@ -116,15 +132,23 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None): ...@@ -116,15 +132,23 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
if i in fgraph_inputs: if i in fgraph_inputs:
idx = fgraph_inputs.index(i) idx = fgraph_inputs.index(i)
def jax_inputs_func(*inputs, i_dtype=i.dtype, idx=idx): i_dtype = getattr(i, "dtype", None)
def jax_inputs_func(*inputs, i_dtype=i_dtype, idx=idx):
return jnp.array(inputs[idx], dtype=jnp.dtype(i_dtype)) return jnp.array(inputs[idx], dtype=jnp.dtype(i_dtype))
input_f = jax_inputs_func input_f = jax_inputs_func
elif i.owner is None: elif i.owner is None:
def jax_data_func(*inputs, i_dtype=i.dtype, i_data=i.data): i_dtype = getattr(i, "dtype", None)
return jnp.array(i_data, dtype=jnp.dtype(i_dtype)) i_data = i.data
def jax_data_func(*inputs, i_dtype=i_dtype, i_data=i_data):
if i_dtype is None:
return i_data
else:
return jnp.array(i_data, dtype=jnp.dtype(i_dtype))
input_f = jax_data_func input_f = jax_data_func
else: else:
...@@ -158,6 +182,14 @@ def jax_funcify(op): ...@@ -158,6 +182,14 @@ def jax_funcify(op):
raise NotImplementedError("No JAX conversion for the given `Op`: {}".format(op)) raise NotImplementedError("No JAX conversion for the given `Op`: {}".format(op))
@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op):
    """JAX conversion for `MakeSlice`: build a Python `slice` from its parts."""

    def makeslice(*slice_args):
        # `slice_args` are the (start, stop, step) entries of the graph's slice.
        return slice(*slice_args)

    return makeslice
@jax_funcify.register(ScalarOp) @jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op): def jax_funcify_ScalarOp(op):
func_name = op.nfunc_spec[0] func_name = op.nfunc_spec[0]
...@@ -288,8 +320,13 @@ def jax_funcify_Shape_i(op): ...@@ -288,8 +320,13 @@ def jax_funcify_Shape_i(op):
@jax_funcify.register(SpecifyShape) @jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op): def jax_funcify_SpecifyShape(op):
def specifyshape(x, shape): def specifyshape(x, shape):
assert x.ndim == shape.size assert x.ndim == len(shape)
assert jnp.all(x.shape == shape), ("got shape", x.shape, "expected", shape) assert jnp.all(x.shape == tuple(shape)), (
"got shape",
x.shape,
"expected",
shape,
)
return x return x
return specifyshape return specifyshape
...@@ -475,11 +512,15 @@ def jax_funcify_Scan(op): ...@@ -475,11 +512,15 @@ def jax_funcify_Scan(op):
@jax_funcify.register(IfElse) @jax_funcify.register(IfElse)
def jax_funcify_IfElse(op): def jax_funcify_IfElse(op):
def ifelse(cond, *args): n_outs = op.n_outs
def ifelse(cond, *args, n_outs=n_outs):
if cond: if cond:
return args[: op.n_outs] res = args[:n_outs]
else: else:
return args[op.n_outs :] res = args[n_outs:]
return res if n_outs > 1 else res[0]
return ifelse return ifelse
...@@ -526,14 +567,16 @@ _ = [jax_funcify.register(op, jax_funcify_Subtensor) for op in subtensor_ops] ...@@ -526,14 +567,16 @@ _ = [jax_funcify.register(op, jax_funcify_Subtensor) for op in subtensor_ops]
def jax_funcify_IncSubtensor(op): def jax_funcify_IncSubtensor(op):
idx_list = op.idx_list
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
jax_fn = jax.ops.index_update jax_fn = jax.ops.index_update
else: else:
jax_fn = jax.ops.index_add jax_fn = jax.ops.index_add
def incsubtensor(x, y, *ilist, jax_fn=jax_fn): def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
_ilist = list(ilist) _ilist = list(ilist)
cdata = tuple(convert_indices(_ilist, idx) for idx in op.idx_list) cdata = tuple(convert_indices(_ilist, idx) for idx in idx_list)
if len(cdata) == 1: if len(cdata) == 1:
cdata = cdata[0] cdata = cdata[0]
...@@ -545,6 +588,20 @@ def jax_funcify_IncSubtensor(op): ...@@ -545,6 +588,20 @@ def jax_funcify_IncSubtensor(op):
_ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_ops] _ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_ops]
@jax_funcify.register(BaseAdvancedIncSubtensor)
def jax_funcify_BaseAdvancedIncSubtensor(op):
    """JAX conversion for advanced inc/set-subtensor `Op`s.

    Chooses between "set" and "increment" semantics once, at conversion
    time, based on the `Op`'s `set_instead_of_inc` flag.
    """
    set_instead_of_inc = getattr(op, "set_instead_of_inc", False)
    jax_fn = jax.ops.index_update if set_instead_of_inc else jax.ops.index_add

    def baseadvancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
        return jax_fn(x, ilist, y)

    return baseadvancedincsubtensor
@jax_funcify.register(FunctionGraph) @jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph(fgraph): def jax_funcify_FunctionGraph(fgraph):
...@@ -656,6 +713,44 @@ def jax_funcify_Join(op): ...@@ -656,6 +713,44 @@ def jax_funcify_Join(op):
return join return join
@jax_funcify.register(MaxAndArgmax)
def jax_funcify_MaxAndArgmax(op):
    """JAX conversion for `MaxAndArgmax`.

    Returns a function computing both the maximum and the argmax of `x`
    over the `Op`'s (possibly multiple) reduction axes.
    """
    axis = op.axis
    def maxandargmax(x, axis=axis):
        # Normalize `axis` to a concrete tuple of ints; `None` means all axes.
        if axis is None:
            axes = tuple(range(x.ndim))
        else:
            axes = tuple(int(ax) for ax in axis)
        max_res = jnp.max(x, axis)
        # NumPy does not support multiple axes for argmax; this is a
        # work-around
        keep_axes = jnp.array(
            [i for i in range(x.ndim) if i not in axes], dtype="int64"
        )
        # Not-reduced axes in front
        transposed_x = jnp.transpose(
            x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
        )
        kept_shape = transposed_x.shape[: len(keep_axes)]
        reduced_shape = transposed_x.shape[len(keep_axes) :]
        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
        # Otherwise reshape would complain citing float arg
        new_shape = kept_shape + (
            jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
        )
        # Collapse all reduced axes into one trailing axis, then take a
        # single-axis argmax over it.
        reshaped_x = transposed_x.reshape(new_shape)
        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
        return max_res, max_idx_res
    return maxandargmax
@jax_funcify.register(ExtractDiag) @jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op): def jax_funcify_ExtractDiag(op):
offset = op.offset offset = op.offset
...@@ -763,3 +858,141 @@ def jax_funcify_SVD(op): ...@@ -763,3 +858,141 @@ def jax_funcify_SVD(op):
return jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) return jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
return svd return svd
@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op):
    """JAX conversion for `CumOp` (cumulative sum or cumulative product)."""
    axis = op.axis
    mode = op.mode

    def cumop(x, axis=axis, mode=mode):
        # `mode == "add"` selects cumsum; anything else means cumprod.
        fn = jnp.cumsum if mode == "add" else jnp.cumprod
        return fn(x, axis=axis)

    return cumop
@jax_funcify.register(DiffOp)
def jax_funcify_DiffOp(op):
    """JAX conversion for `DiffOp` (n-th order discrete difference)."""
    # Capture the `Op` parameters once, at conversion time.
    n, axis = op.n, op.axis

    def diffop(x, n=n, axis=axis):
        return jnp.diff(x, n=n, axis=axis)

    return diffop
@jax_funcify.register(RepeatOp)
def jax_funcify_RepeatOp(op):
    """JAX conversion for `RepeatOp`."""
    op_axis = op.axis

    def repeatop(x, repeats, axis=op_axis):
        return jnp.repeat(x, repeats, axis=axis)

    return repeatop
@jax_funcify.register(Bartlett)
def jax_funcify_Bartlett(op):
    """JAX conversion for `Bartlett` (Bartlett window of a given length)."""

    def bartlett(window_length):
        return jnp.bartlett(window_length)

    return bartlett
@jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op):
    """`FillDiagonal` has no JAX conversion.

    A NumPy-style implementation would assign through ``a.flat``, but JAX
    arrays are immutable and expose no `flatiter` interface, so the
    conversion is rejected up front.
    """
    raise NotImplementedError("flatiter not implemented in JAX")
@jax_funcify.register(FillDiagonalOffset)
def jax_funcify_FillDiagonalOffset(op):
    """`FillDiagonalOffset` has no JAX conversion.

    Like `FillDiagonal`, a NumPy-style implementation would write through
    ``a.flat`` with a strided slice, which JAX's immutable arrays do not
    support (no `flatiter`), so the conversion is rejected up front.
    """
    raise NotImplementedError("flatiter not implemented in JAX")
@jax_funcify.register(Unique)
def jax_funcify_Unique(op):
    """JAX conversion for `Unique`.

    Forwards to `jnp.unique`, passing only the options that were actually
    requested on the `Op`.
    """
    return_index = op.return_index
    return_inverse = op.return_inverse
    return_counts = op.return_counts
    axis = op.axis

    def unique(
        x,
        return_index=return_index,
        return_inverse=return_inverse,
        return_counts=return_counts,
        axis=axis,
    ):
        # Build the keyword set from the enabled flags only.
        kwargs = {
            name: True
            for name, enabled in (
                ("return_index", return_index),
                ("return_inverse", return_inverse),
                ("return_counts", return_counts),
            )
            if enabled
        }
        if axis is not None:
            kwargs["axis"] = axis
        return jnp.unique(x, **kwargs)

    return unique
@jax_funcify.register(UnravelIndex)
def jax_funcify_UnravelIndex(op):
    """JAX conversion for `UnravelIndex`.

    `jnp.unravel_index` provides no ``order`` option, so the `Op`'s
    ``order`` setting is dropped (a warning is emitted at conversion time).
    """
    order = op.order
    warn("JAX ignores the `order` parameter in `unravel_index`.")

    def unravelindex(indices, dims, order=order):
        return jnp.unravel_index(indices, dims)

    return unravelindex
@jax_funcify.register(RavelMultiIndex)
def jax_funcify_RavelMultiIndex(op):
    """JAX conversion for `RavelMultiIndex`."""
    mode, order = op.mode, op.order

    def ravelmultiindex(*inp, mode=mode, order=order):
        # All leading inputs are the coordinate arrays; the last one is `dims`.
        multi_index, dims = inp[:-1], inp[-1]
        return jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order)

    return ravelmultiindex
...@@ -5002,6 +5002,8 @@ def get_vector_length(v): ...@@ -5002,6 +5002,8 @@ def get_vector_length(v):
raise TypeError("argument must be symbolic vector, got '%s'" % v) raise TypeError("argument must be symbolic vector, got '%s'" % v)
if v.type.broadcastable[0]: if v.type.broadcastable[0]:
return 1 return 1
if isinstance(v, theano.tensor.sharedvar.TensorSharedVariable) and v.type.ndim == 1:
return len(v.get_value())
if isinstance(v, gof.Constant) and v.type.ndim == 1: if isinstance(v, gof.Constant) and v.type.ndim == 1:
return len(v.data) return len(v.data)
if v.owner and isinstance(v.owner.op, theano.tensor.opt.MakeVector): if v.owner and isinstance(v.owner.op, theano.tensor.opt.MakeVector):
......
...@@ -294,7 +294,10 @@ class CumOp(theano.Op): ...@@ -294,7 +294,10 @@ class CumOp(theano.Op):
def perform(self, node, inputs, output_storage, params): def perform(self, node, inputs, output_storage, params):
x = inputs[0] x = inputs[0]
z = output_storage[0] z = output_storage[0]
z[0] = {"add": np.cumsum, "mul": np.cumprod}[self.mode](x, axis=self.axis) if self.mode == "add":
z[0] = np.cumsum(x, axis=self.axis)
else:
z[0] = np.cumprod(x, axis=self.axis)
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
(x,) = inputs (x,) = inputs
...@@ -1289,13 +1292,10 @@ class Unique(theano.Op): ...@@ -1289,13 +1292,10 @@ class Unique(theano.Op):
class UnravelIndex(gof.Op): class UnravelIndex(gof.Op):
__props__ = ("ndim", "order") __props__ = ("order",)
def __init__(self, ndim, order="C"): def __init__(self, order="C"):
assert order in ("C", "F") assert order in ("C", "F")
if not isinstance(ndim, int) or ndim < 1:
raise ValueError("ndim must be an integer greater than 0")
self.ndim = int(ndim)
self.order = order self.order = order
def make_node(self, indices, dims): def make_node(self, indices, dims):
...@@ -1318,7 +1318,7 @@ class UnravelIndex(gof.Op): ...@@ -1318,7 +1318,7 @@ class UnravelIndex(gof.Op):
[indices, dims], [indices, dims],
[ [
basic.TensorType(dtype="int64", broadcastable=(False,) * indices.ndim)() basic.TensorType(dtype="int64", broadcastable=(False,) * indices.ndim)()
for i in range(self.ndim) for i in range(basic.get_vector_length(dims))
], ],
) )
...@@ -1327,7 +1327,7 @@ class UnravelIndex(gof.Op): ...@@ -1327,7 +1327,7 @@ class UnravelIndex(gof.Op):
def perform(self, node, inp, out): def perform(self, node, inp, out):
indices, dims = inp indices, dims = inp
res = np.unravel_index(indices, dims) res = np.unravel_index(indices, dims, order=self.order)
assert len(res) == len(out) assert len(res) == len(out)
for i in range(len(out)): for i in range(len(out)):
ret = theano._asarray(res[i], node.outputs[0].dtype) ret = theano._asarray(res[i], node.outputs[0].dtype)
...@@ -1338,15 +1338,11 @@ class UnravelIndex(gof.Op): ...@@ -1338,15 +1338,11 @@ class UnravelIndex(gof.Op):
out[i][0] = ret out[i][0] = ret
def unravel_index(indices, dims, order="C", ndim=None): def unravel_index(indices, dims, order="C"):
""" """
Converts a flat index or array of flat indices into a tuple Converts a flat index or array of flat indices into a tuple
of coordinate arrays. of coordinate arrays.
This method is similar to the NumPy version, except for the
additional ``ndim`` parameter. This parameter is required if
the length of ``dims`` cannot be determined automatically.
Parameters Parameters
---------- ----------
indices : Theano or NumPy array indices : Theano or NumPy array
...@@ -1357,10 +1353,6 @@ def unravel_index(indices, dims, order="C", ndim=None): ...@@ -1357,10 +1353,6 @@ def unravel_index(indices, dims, order="C", ndim=None):
order : {'C', 'F'}, optional order : {'C', 'F'}, optional
Determines whether the indices should be viewed as indexing in Determines whether the indices should be viewed as indexing in
row-major (C-style) or column-major (Fortran-style) order. row-major (C-style) or column-major (Fortran-style) order.
ndim : int, optional
Specifies the number of dimensions, i.e., the length of
``dims``. This is required if the dimensions cannot be determined
automatically from ``dims`` itself.
Returns Returns
------- -------
...@@ -1373,20 +1365,8 @@ def unravel_index(indices, dims, order="C", ndim=None): ...@@ -1373,20 +1365,8 @@ def unravel_index(indices, dims, order="C", ndim=None):
ravel_multi_index ravel_multi_index
""" """
if ndim is None: res = UnravelIndex(order=order)(indices, dims)
try: if not isinstance(res, (list, tuple)):
ndim = basic.get_vector_length(dims)
except ValueError:
raise ValueError(
"The length of the provided dimension list (%s) cannot "
"be automatically determined, so Theano is not able "
"to know what the number of dimensions of the unraveled "
"index will be. You can provide the 'ndim' keyword "
"argument to 'unravel_index' to avoid this problem." % str(dims)
)
res = UnravelIndex(ndim=ndim, order=order)(indices, dims)
if ndim == 1:
return (res,) return (res,)
else: else:
return tuple(res) return tuple(res)
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment