提交 27c48a87 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make jax.numpy.unique tests with symbolic input expected fails

上级 cd803dee
...@@ -484,21 +484,27 @@ def test_jax_MakeVector(): ...@@ -484,21 +484,27 @@ def test_jax_MakeVector():
def test_jax_Reshape(): def test_jax_Reshape():
a_tt = tt.vector("a") a = tt.vector("a")
x = tt.basic.reshape(a_tt, (2, 2)) x = tt.basic.reshape(a, (2, 2))
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)]) compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)])
def test_jax_Reshape_omnistaging():
# Test breaking "omnistaging" changes in JAX. # Test breaking "omnistaging" changes in JAX.
# See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68 # See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68
a_tt = tt.vector("a") x = tt.basic.reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
x = tt.basic.reshape(a_tt, (a_tt.shape[0] // 2, a_tt.shape[0] // 3)) x_fg = theano.gof.FunctionGraph([a], [x])
x_fg = theano.gof.FunctionGraph([a_tt], [x]) compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)])
compare_jax_and_py(x_fg, [np.empty((6,)).astype(theano.config.floatX)])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_jax_Reshape_nonconcrete():
a = tt.vector("a")
b = tt.iscalar("b")
x = tt.basic.reshape(a, (b, b))
x_fg = theano.gof.FunctionGraph([a, b], [x])
compare_jax_and_py(
x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX), 2]
)
def test_jax_Dimshuffle(): def test_jax_Dimshuffle():
...@@ -623,7 +629,8 @@ def test_tensor_basics(): ...@@ -623,7 +629,8 @@ def test_tensor_basics():
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs") @pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_arange(): def test_arange_nonconcrete():
a = tt.scalar("a") a = tt.scalar("a")
a.tag.test_value = 10 a.tag.test_value = 10
...@@ -632,6 +639,16 @@ def test_arange(): ...@@ -632,6 +639,16 @@ def test_arange():
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_unique_nonconcrete():
a = tt.matrix("a")
a.tag.test_value = np.arange(6, dtype=theano.config.floatX).reshape((3, 2))
out = tt.extra_ops.Unique()(a)
fgraph = theano.gof.FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_identity(): def test_identity():
a = tt.scalar("a") a = tt.scalar("a")
a.tag.test_value = 10 a.tag.test_value = 10
...@@ -720,3 +737,10 @@ def test_extra_ops(): ...@@ -720,3 +737,10 @@ def test_extra_ops():
compare_jax_and_py( compare_jax_and_py(
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
) )
# The inputs are "concrete", yet it still has problems?
out = tt.extra_ops.Unique()(
tt.as_tensor(np.arange(6, dtype=theano.config.floatX).reshape((3, 2)))
)
fgraph = theano.gof.FunctionGraph([], [out])
compare_jax_and_py(fgraph, [])
...@@ -947,10 +947,16 @@ def jax_funcify_FillDiagonalOffset(op): ...@@ -947,10 +947,16 @@ def jax_funcify_FillDiagonalOffset(op):
@jax_funcify.register(Unique) @jax_funcify.register(Unique)
def jax_funcify_Unique(op): def jax_funcify_Unique(op):
axis = op.axis
if axis is not None:
raise NotImplementedError(
"jax.numpy.unique is not implemented for the axis argument"
)
return_index = op.return_index return_index = op.return_index
return_inverse = op.return_inverse return_inverse = op.return_inverse
return_counts = op.return_counts return_counts = op.return_counts
axis = op.axis
def unique( def unique(
x, x,
...@@ -959,17 +965,11 @@ def jax_funcify_Unique(op): ...@@ -959,17 +965,11 @@ def jax_funcify_Unique(op):
return_counts=return_counts, return_counts=return_counts,
axis=axis, axis=axis,
): ):
param = {} ret = jnp.lax_numpy._unique1d(x, return_index, return_inverse, return_counts)
if return_index: if len(ret) == 1:
param["return_index"] = True return ret[0]
if return_inverse: else:
param["return_inverse"] = True return ret
if return_counts:
param["return_counts"] = True
if axis is not None:
param["axis"] = axis
return jnp.unique(x, **param)
return unique return unique
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论