提交 27c48a87 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make jax.numpy.unique tests with symbolic input expected fails

上级 cd803dee
......@@ -484,21 +484,27 @@ def test_jax_MakeVector():
def test_jax_Reshape():
a_tt = tt.vector("a")
x = tt.basic.reshape(a_tt, (2, 2))
x_fg = theano.gof.FunctionGraph([a_tt], [x])
a = tt.vector("a")
x = tt.basic.reshape(a, (2, 2))
x_fg = theano.gof.FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)])
def test_jax_Reshape_omnistaging():
# Test breaking "omnistaging" changes in JAX.
# See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68
a_tt = tt.vector("a")
x = tt.basic.reshape(a_tt, (a_tt.shape[0] // 2, a_tt.shape[0] // 3))
x_fg = theano.gof.FunctionGraph([a_tt], [x])
x = tt.basic.reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
x_fg = theano.gof.FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)])
compare_jax_and_py(x_fg, [np.empty((6,)).astype(theano.config.floatX)])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_jax_Reshape_nonconcrete():
a = tt.vector("a")
b = tt.iscalar("b")
x = tt.basic.reshape(a, (b, b))
x_fg = theano.gof.FunctionGraph([a, b], [x])
compare_jax_and_py(
x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX), 2]
)
def test_jax_Dimshuffle():
......@@ -623,7 +629,8 @@ def test_tensor_basics():
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_arange():
def test_arange_nonconcrete():
a = tt.scalar("a")
a.tag.test_value = 10
......@@ -632,6 +639,16 @@ def test_arange():
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_unique_nonconcrete():
a = tt.matrix("a")
a.tag.test_value = np.arange(6, dtype=theano.config.floatX).reshape((3, 2))
out = tt.extra_ops.Unique()(a)
fgraph = theano.gof.FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_identity():
a = tt.scalar("a")
a.tag.test_value = 10
......@@ -720,3 +737,10 @@ def test_extra_ops():
compare_jax_and_py(
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)
# The inputs are "concrete", yet it still has problems?
out = tt.extra_ops.Unique()(
tt.as_tensor(np.arange(6, dtype=theano.config.floatX).reshape((3, 2)))
)
fgraph = theano.gof.FunctionGraph([], [out])
compare_jax_and_py(fgraph, [])
......@@ -947,10 +947,16 @@ def jax_funcify_FillDiagonalOffset(op):
@jax_funcify.register(Unique)
def jax_funcify_Unique(op):
axis = op.axis
if axis is not None:
raise NotImplementedError(
"jax.numpy.unique is not implemented for the axis argument"
)
return_index = op.return_index
return_inverse = op.return_inverse
return_counts = op.return_counts
axis = op.axis
def unique(
x,
......@@ -959,17 +965,11 @@ def jax_funcify_Unique(op):
return_counts=return_counts,
axis=axis,
):
param = {}
if return_index:
param["return_index"] = True
if return_inverse:
param["return_inverse"] = True
if return_counts:
param["return_counts"] = True
if axis is not None:
param["axis"] = axis
return jnp.unique(x, **param)
ret = jnp.lax_numpy._unique1d(x, return_index, return_inverse, return_counts)
if len(ret) == 1:
return ret[0]
else:
return ret
return unique
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论