提交 bf752d81 authored 作者: junpenglao's avatar junpenglao 提交者: Brandon T. Willard

Change tt.config to theano.config in test_jax

上级 0ea34350
......@@ -111,7 +111,7 @@ def test_jax_Alloc():
x = tt.alloc(a, 20, 10)
x_fg = theano.gof.FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.ones(10, dtype=tt.config.floatX)])
compare_jax_and_py(x_fg, [np.ones(10, dtype=theano.config.floatX)])
def test_jax_compile_ops():
......@@ -182,8 +182,8 @@ def test_jax_basic():
out_fg = theano.gof.FunctionGraph([x, y], [out])
test_input_vals = [
np.tile(np.arange(10), (10, 1)).astype(tt.config.floatX),
np.tile(np.arange(10, 20), (10, 1)).astype(tt.config.floatX),
np.tile(np.arange(10), (10, 1)).astype(theano.config.floatX),
np.tile(np.arange(10, 20), (10, 1)).astype(theano.config.floatX),
]
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
......@@ -201,43 +201,49 @@ def test_jax_basic():
out = tt.diagonal(x, 0)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(tt.config.floatX)]
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(theano.config.floatX)]
)
out = tt.slinalg.cholesky(x)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(tt.config.floatX)]
out_fg,
[(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(theano.config.floatX)],
)
# not sure why this isn't working yet with lower=False
out = tt.slinalg.Cholesky(lower=False)(x)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(tt.config.floatX)]
out_fg,
[(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(theano.config.floatX)],
)
out = tt.slinalg.solve(x, b)
out_fg = theano.gof.FunctionGraph([x, b], [out])
compare_jax_and_py(
out_fg,
[np.eye(10).astype(tt.config.floatX), np.arange(10).astype(tt.config.floatX)],
[
np.eye(10).astype(theano.config.floatX),
np.arange(10).astype(theano.config.floatX),
],
)
out = tt.nlinalg.alloc_diag(b)
out_fg = theano.gof.FunctionGraph([b], [out])
compare_jax_and_py(out_fg, [np.arange(10).astype(tt.config.floatX)])
compare_jax_and_py(out_fg, [np.arange(10).astype(theano.config.floatX)])
out = tt.nlinalg.det(x)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(tt.config.floatX)]
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(theano.config.floatX)]
)
out = tt.nlinalg.matrix_inverse(x)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(tt.config.floatX)]
out_fg,
[(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(theano.config.floatX)],
)
......@@ -261,25 +267,25 @@ def test_jax_basic_multiout():
out_fg = theano.gof.FunctionGraph([x], outs)
def assert_fn(x, y):
np.testing.assert_allclose(x.astype(tt.config.floatX), y, rtol=1e-3)
np.testing.assert_allclose(x.astype(theano.config.floatX), y, rtol=1e-3)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
outs = tt.nlinalg.eigh(x)
out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
outs = tt.nlinalg.qr(x, mode="full")
out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
outs = tt.nlinalg.qr(x, mode="reduced")
out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
outs = tt.nlinalg.svd(x)
out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
......@@ -357,10 +363,11 @@ def test_jax_scan_multiple_output():
)
s0, e0, i0 = 100, 50, 25
logp_c0 = np.array(0.0).astype(tt.config.floatX)
logp_d0 = np.array(0.0).astype(tt.config.floatX)
logp_c0 = np.array(0.0, dtype=theano.config.floatX)
logp_d0 = np.array(0.0, dtype=theano.config.floatX)
beta_val, gamma_val, delta_val = [
np.array(val).astype(tt.config.floatX) for val in [0.277792, 0.135330, 0.108753]
np.array(val, dtype=theano.config.floatX)
for val in [0.277792, 0.135330, 0.108753]
]
C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)
......@@ -396,7 +403,7 @@ def test_jax_scan_tap_output():
outputs_info=[
{
"initial": tt.as_tensor_variable(
np.r_[-1.0, 1.3, 0.0].astype(tt.config.floatX)
np.r_[-1.0, 1.3, 0.0].astype(theano.config.floatX)
),
"taps": [-1, -3],
},
......@@ -410,7 +417,7 @@ def test_jax_scan_tap_output():
out_fg = theano.gof.FunctionGraph([a_tt], [y_scan_tt])
test_input_vals = [np.array(10.0).astype(tt.config.floatX)]
test_input_vals = [np.array(10.0).astype(theano.config.floatX)]
compare_jax_and_py(out_fg, test_input_vals)
......@@ -457,16 +464,16 @@ def test_jax_Subtensors():
def test_jax_IncSubtensor():
x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(tt.config.floatX)
x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(tt.config.floatX)
x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(theano.config.floatX)
x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(theano.config.floatX)
# "Set" basic indices
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX))
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=theano.config.floatX))
out_tt = tt.set_subtensor(x_tt[1, 2, 3], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.set_subtensor(x_tt[:2, 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
......@@ -476,7 +483,7 @@ def test_jax_IncSubtensor():
compare_jax_and_py(out_fg, [])
# "Set" advanced indices
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.set_subtensor(x_tt[[0, 2], 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
......@@ -493,12 +500,12 @@ def test_jax_IncSubtensor():
compare_jax_and_py(out_fg, [])
# "Increment" basic indices
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX))
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=theano.config.floatX))
out_tt = tt.inc_subtensor(x_tt[1, 2, 3], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.inc_subtensor(x_tt[:2, 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
......@@ -508,7 +515,7 @@ def test_jax_IncSubtensor():
compare_jax_and_py(out_fg, [])
# "Increment" advanced indices
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
......@@ -545,38 +552,38 @@ def test_jax_ifelse():
def test_jax_CAReduce():
a_tt = tt.vector("a")
a_tt.tag.test_value = np.r_[1, 2, 3].astype(tt.config.floatX)
a_tt.tag.test_value = np.r_[1, 2, 3].astype(theano.config.floatX)
x = tt.sum(a_tt, axis=None)
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(theano.config.floatX)])
a_tt = tt.matrix("a")
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)
x = tt.sum(a_tt, axis=0)
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)])
x = tt.sum(a_tt, axis=1)
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)])
a_tt = tt.matrix("a")
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)
x = tt.prod(a_tt, axis=0)
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)])
x = tt.all(a_tt)
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)])
def test_jax_MakeVector():
......@@ -615,28 +622,32 @@ def test_jax_Dimshuffle():
x = a_tt.T
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)])
compare_jax_and_py(
x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(theano.config.floatX)]
)
x = a_tt.dimshuffle([0, 1, "x"])
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)])
compare_jax_and_py(
x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(theano.config.floatX)]
)
a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True])
a_tt = tt.tensor(dtype=theano.config.floatX, broadcastable=[False, True])
x = a_tt.dimshuffle((0,))
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(theano.config.floatX)])
a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True])
a_tt = tt.tensor(dtype=theano.config.floatX, broadcastable=[False, True])
x = tt.elemwise.DimShuffle([False, True], (0,), inplace=True)(a_tt)
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(theano.config.floatX)])
def test_jax_variadic_Scalar():
mu = tt.vector("mu", dtype=tt.config.floatX)
mu.tag.test_value = np.r_[0.1, 1.1].astype(tt.config.floatX)
tau = tt.vector("tau", dtype=tt.config.floatX)
tau.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX)
mu = tt.vector("mu", dtype=theano.config.floatX)
mu.tag.test_value = np.r_[0.1, 1.1].astype(theano.config.floatX)
tau = tt.vector("tau", dtype=theano.config.floatX)
tau.tag.test_value = np.r_[1.0, 2.0].astype(theano.config.floatX)
res = -tau * mu
......@@ -654,13 +665,13 @@ def test_jax_variadic_Scalar():
def test_jax_logp():
mu = tt.vector("mu")
mu.tag.test_value = np.r_[0.0, 0.0].astype(tt.config.floatX)
mu.tag.test_value = np.r_[0.0, 0.0].astype(theano.config.floatX)
tau = tt.vector("tau")
tau.tag.test_value = np.r_[1.0, 1.0].astype(tt.config.floatX)
tau.tag.test_value = np.r_[1.0, 1.0].astype(theano.config.floatX)
sigma = tt.vector("sigma")
sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(tt.config.floatX)
sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(theano.config.floatX)
value = tt.vector("value")
value.tag.test_value = np.r_[0.1, -10].astype(tt.config.floatX)
value.tag.test_value = np.r_[0.1, -10].astype(theano.config.floatX)
logp = (-tau * (value - mu) ** 2 + tt.log(tau / np.pi / 2.0)) / 2.0
conditions = [sigma > 0]
......@@ -674,9 +685,9 @@ def test_jax_logp():
def test_jax_multioutput():
x = tt.vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX)
x.tag.test_value = np.r_[1.0, 2.0].astype(theano.config.floatX)
y = tt.vector("y")
y.tag.test_value = np.r_[3.0, 4.0].astype(tt.config.floatX)
y.tag.test_value = np.r_[3.0, 4.0].astype(theano.config.floatX)
w = tt.cosh(x ** 2 + y / 3.0)
v = tt.cosh(x / 3.0 + y ** 2)
......@@ -688,7 +699,7 @@ def test_jax_multioutput():
def test_nnet():
x = tt.vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX)
x.tag.test_value = np.r_[1.0, 2.0].astype(theano.config.floatX)
out = tt.nnet.sigmoid(x)
fgraph = theano.gof.FunctionGraph([x], [out])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论