提交 4a8ccb6d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename _tt suffix to _aet

上级 9ac485c5
......@@ -24,9 +24,9 @@ class Assert(COp):
Examples
--------
>>> import aesara
>>> T = aesara.tensor
>>> x = T.vector('x')
>>> assert_op = T.opt.Assert()
>>> import aesara.tensor as aet
>>> x = aet.vector('x')
>>> assert_op = aet.opt.Assert()
>>> func = aesara.function([x], assert_op(x, x.size<2))
"""
......
......@@ -419,7 +419,7 @@ def jax_funcify_Composite(op):
@jax_funcify.register(Scan)
def jax_funcify_Scan(op):
inner_fg = FunctionGraph(op.inputs, op.outputs)
jax_tt_inner_func = jax_funcify(inner_fg)
jax_aet_inner_func = jax_funcify(inner_fg)
def scan(*outer_inputs):
scan_args = ScanArgs(
......@@ -538,7 +538,7 @@ def jax_funcify_Scan(op):
def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_scan(op, carry, x)
inner_scan_outs = [fn(*inner_args) for fn in jax_tt_inner_func]
inner_scan_outs = [fn(*inner_args) for fn in jax_aet_inner_func]
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
return new_carry, inner_scan_outs
......
......@@ -424,7 +424,7 @@ def test_jax_scan_multiple_output():
def test_jax_scan_tap_output():
a_tt = scalar("a")
a_aet = scalar("a")
def input_step_fn(y_tm1, y_tm3, a):
y_tm1.name = "y_tm1"
......@@ -433,7 +433,7 @@ def test_jax_scan_tap_output():
res.name = "y_t"
return res
y_scan_tt, _ = scan(
y_scan_aet, _ = scan(
fn=input_step_fn,
outputs_info=[
{
......@@ -443,14 +443,14 @@ def test_jax_scan_tap_output():
"taps": [-1, -3],
},
],
non_sequences=[a_tt],
non_sequences=[a_aet],
n_steps=10,
name="y_scan",
)
y_scan_tt.name = "y"
y_scan_tt.owner.inputs[0].name = "y_all"
y_scan_aet.name = "y"
y_scan_aet.owner.inputs[0].name = "y_all"
out_fg = FunctionGraph([a_tt], [y_scan_tt])
out_fg = FunctionGraph([a_aet], [y_scan_aet])
test_input_vals = [np.array(10.0).astype(config.floatX)]
compare_jax_and_py(out_fg, test_input_vals)
......@@ -458,140 +458,140 @@ def test_jax_scan_tap_output():
def test_jax_Subtensors():
# Basic indices
x_tt = aet.arange(3 * 4 * 5).reshape((3, 4, 5))
out_tt = x_tt[1, 2, 0]
assert isinstance(out_tt.owner.op, aet_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_tt])
x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5))
out_aet = x_aet[1, 2, 0]
assert isinstance(out_aet.owner.op, aet_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
out_tt = x_tt[1:2, 1, :]
assert isinstance(out_tt.owner.op, aet_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_tt])
out_aet = x_aet[1:2, 1, :]
assert isinstance(out_aet.owner.op, aet_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
# Boolean indices
out_tt = x_tt[x_tt < 0]
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_tt])
out_aet = x_aet[x_aet < 0]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
# Advanced indexing
out_tt = x_tt[[1, 2]]
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_tt])
out_aet = x_aet[[1, 2]]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
out_tt = x_tt[[1, 2], [2, 3]]
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_tt])
out_aet = x_aet[[1, 2], [2, 3]]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
# Advanced and basic indexing
out_tt = x_tt[[1, 2], :]
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_tt])
out_aet = x_aet[[1, 2], :]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
out_tt = x_tt[[1, 2], :, [3, 4]]
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_tt])
out_aet = x_aet[[1, 2], :, [3, 4]]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
def test_jax_IncSubtensor():
x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_tt = aet.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
# "Set" basic indices
st_tt = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_tt = aet_subtensor.set_subtensor(x_tt[1, 2, 3], st_tt)
assert isinstance(out_tt.owner.op, aet_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_tt])
st_aet = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_aet = aet_subtensor.set_subtensor(x_aet[1, 2, 3], st_aet)
assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
st_tt = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_tt = aet_subtensor.set_subtensor(x_tt[:2, 0, 0], st_tt)
assert isinstance(out_tt.owner.op, aet_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_tt])
st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_aet = aet_subtensor.set_subtensor(x_aet[:2, 0, 0], st_aet)
assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
out_tt = aet_subtensor.set_subtensor(x_tt[0, 1:3, 0], st_tt)
assert isinstance(out_tt.owner.op, aet_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_tt])
out_aet = aet_subtensor.set_subtensor(x_aet[0, 1:3, 0], st_aet)
assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
# "Set" advanced indices
st_tt = aet.as_tensor_variable(
st_aet = aet.as_tensor_variable(
np.random.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
)
out_tt = aet_subtensor.set_subtensor(x_tt[np.r_[0, 2]], st_tt)
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_tt])
out_aet = aet_subtensor.set_subtensor(x_aet[np.r_[0, 2]], st_aet)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
st_tt = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_tt = aet_subtensor.set_subtensor(x_tt[[0, 2], 0, 0], st_tt)
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_tt])
st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_aet = aet_subtensor.set_subtensor(x_aet[[0, 2], 0, 0], st_aet)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
st_tt = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
out_tt = aet_subtensor.set_subtensor(x_tt[[0, 2], 0, :3], st_tt)
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_tt])
st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
out_aet = aet_subtensor.set_subtensor(x_aet[[0, 2], 0, :3], st_aet)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
# "Set" boolean indices
mask_tt = aet.as_tensor_variable(x_np) > 0
out_tt = aet_subtensor.set_subtensor(x_tt[mask_tt], 0.0)
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_tt])
mask_aet = aet.as_tensor_variable(x_np) > 0
out_aet = aet_subtensor.set_subtensor(x_aet[mask_aet], 0.0)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
# "Increment" basic indices
st_tt = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_tt = aet_subtensor.inc_subtensor(x_tt[1, 2, 3], st_tt)
assert isinstance(out_tt.owner.op, aet_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_tt])
st_aet = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_aet = aet_subtensor.inc_subtensor(x_aet[1, 2, 3], st_aet)
assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
st_tt = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_tt = aet_subtensor.inc_subtensor(x_tt[:2, 0, 0], st_tt)
assert isinstance(out_tt.owner.op, aet_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_tt])
st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_aet = aet_subtensor.inc_subtensor(x_aet[:2, 0, 0], st_aet)
assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
out_tt = aet_subtensor.set_subtensor(x_tt[0, 1:3, 0], st_tt)
assert isinstance(out_tt.owner.op, aet_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_tt])
out_aet = aet_subtensor.set_subtensor(x_aet[0, 1:3, 0], st_aet)
assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
# "Increment" advanced indices
st_tt = aet.as_tensor_variable(
st_aet = aet.as_tensor_variable(
np.random.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
)
out_tt = aet_subtensor.inc_subtensor(x_tt[np.r_[0, 2]], st_tt)
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_tt])
out_aet = aet_subtensor.inc_subtensor(x_aet[np.r_[0, 2]], st_aet)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
st_tt = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_tt = aet_subtensor.inc_subtensor(x_tt[[0, 2], 0, 0], st_tt)
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_tt])
st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_aet = aet_subtensor.inc_subtensor(x_aet[[0, 2], 0, 0], st_aet)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
st_tt = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
out_tt = aet_subtensor.inc_subtensor(x_tt[[0, 2], 0, :3], st_tt)
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_tt])
st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
out_aet = aet_subtensor.inc_subtensor(x_aet[[0, 2], 0, :3], st_aet)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
# "Increment" boolean indices
mask_tt = aet.as_tensor_variable(x_np) > 0
out_tt = aet_subtensor.set_subtensor(x_tt[mask_tt], 1.0)
assert isinstance(out_tt.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_tt])
mask_aet = aet.as_tensor_variable(x_np) > 0
out_aet = aet_subtensor.set_subtensor(x_aet[mask_aet], 1.0)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
......@@ -614,37 +614,37 @@ def test_jax_ifelse():
def test_jax_CAReduce():
a_tt = vector("a")
a_tt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)
a_aet = vector("a")
a_aet.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)
x = aet_sum(a_tt, axis=None)
x_fg = FunctionGraph([a_tt], [x])
x = aet_sum(a_aet, axis=None)
x_fg = FunctionGraph([a_aet], [x])
compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)])
a_tt = matrix("a")
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
a_aet = matrix("a")
a_aet.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
x = aet_sum(a_tt, axis=0)
x_fg = FunctionGraph([a_tt], [x])
x = aet_sum(a_aet, axis=0)
x_fg = FunctionGraph([a_aet], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
x = aet_sum(a_tt, axis=1)
x_fg = FunctionGraph([a_tt], [x])
x = aet_sum(a_aet, axis=1)
x_fg = FunctionGraph([a_aet], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
a_tt = matrix("a")
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
a_aet = matrix("a")
a_aet.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
x = prod(a_tt, axis=0)
x_fg = FunctionGraph([a_tt], [x])
x = prod(a_aet, axis=0)
x_fg = FunctionGraph([a_aet], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
x = aet_all(a_tt)
x_fg = FunctionGraph([a_tt], [x])
x = aet_all(a_aet)
x_fg = FunctionGraph([a_aet], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
......@@ -679,24 +679,24 @@ def test_jax_Reshape_nonconcrete():
def test_jax_Dimshuffle():
a_tt = matrix("a")
a_aet = matrix("a")
x = a_tt.T
x_fg = FunctionGraph([a_tt], [x])
x = a_aet.T
x_fg = FunctionGraph([a_aet], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
x = a_tt.dimshuffle([0, 1, "x"])
x_fg = FunctionGraph([a_tt], [x])
x = a_aet.dimshuffle([0, 1, "x"])
x_fg = FunctionGraph([a_aet], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
a_tt = tensor(dtype=config.floatX, broadcastable=[False, True])
x = a_tt.dimshuffle((0,))
x_fg = FunctionGraph([a_tt], [x])
a_aet = tensor(dtype=config.floatX, broadcastable=[False, True])
x = a_aet.dimshuffle((0,))
x_fg = FunctionGraph([a_aet], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
a_tt = tensor(dtype=config.floatX, broadcastable=[False, True])
x = aet_elemwise.DimShuffle([False, True], (0,), inplace=True)(a_tt)
x_fg = FunctionGraph([a_tt], [x])
a_aet = tensor(dtype=config.floatX, broadcastable=[False, True])
x = aet_elemwise.DimShuffle([False, True], (0,), inplace=True)(a_aet)
x_fg = FunctionGraph([a_aet], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
......
......@@ -104,23 +104,26 @@ def test_beta_samples():
def test_normal_infer_shape():
M_tt = iscalar("M")
M_tt.tag.test_value = 3
sd_tt = scalar("sd")
sd_tt.tag.test_value = np.array(1.0, dtype=config.floatX)
M_aet = iscalar("M")
M_aet.tag.test_value = 3
sd_aet = scalar("sd")
sd_aet.tag.test_value = np.array(1.0, dtype=config.floatX)
test_params = [
([aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_tt], None),
([aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_tt], (M_tt,)),
([aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_aet], None),
(
[aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_tt],
(2, M_tt),
[aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_aet],
(M_aet,),
),
([aet.zeros((M_tt,)), sd_tt], None),
([aet.zeros((M_tt,)), sd_tt], (M_tt,)),
([aet.zeros((M_tt,)), sd_tt], (2, M_tt)),
([aet.zeros((M_tt,)), aet.ones((M_tt,))], None),
([aet.zeros((M_tt,)), aet.ones((M_tt,))], (2, M_tt)),
(
[aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_aet],
(2, M_aet),
),
([aet.zeros((M_aet,)), sd_aet], None),
([aet.zeros((M_aet,)), sd_aet], (M_aet,)),
([aet.zeros((M_aet,)), sd_aet], (2, M_aet)),
([aet.zeros((M_aet,)), aet.ones((M_aet,))], None),
([aet.zeros((M_aet,)), aet.ones((M_aet,))], (2, M_aet)),
(
[
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
......@@ -140,12 +143,12 @@ def test_normal_infer_shape():
def test_normal_ShapeFeature():
M_tt = iscalar("M")
M_tt.tag.test_value = 3
sd_tt = scalar("sd")
sd_tt.tag.test_value = np.array(1.0, dtype=config.floatX)
M_aet = iscalar("M")
M_aet.tag.test_value = 3
sd_aet = scalar("sd")
sd_aet.tag.test_value = np.array(1.0, dtype=config.floatX)
d_rv = normal(aet.ones((M_tt,)), sd_tt, size=(2, M_tt))
d_rv = normal(aet.ones((M_aet,)), sd_aet, size=(2, M_aet))
d_rv.tag.test_value
fg = FunctionGraph(
......@@ -294,10 +297,10 @@ def test_mvnormal_samples():
def test_mvnormal_ShapeFeature():
M_tt = iscalar("M")
M_tt.tag.test_value = 2
M_aet = iscalar("M")
M_aet.tag.test_value = 2
d_rv = multivariate_normal(aet.ones((M_tt,)), aet.eye(M_tt), size=2)
d_rv = multivariate_normal(aet.ones((M_aet,)), aet.eye(M_aet), size=2)
fg = FunctionGraph(
[i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
......@@ -309,7 +312,7 @@ def test_mvnormal_ShapeFeature():
s1, s2 = fg.shape_feature.shape_of[d_rv]
assert get_test_value(s1) == 2
assert M_tt in graph_inputs([s2])
assert M_aet in graph_inputs([s2])
# Test broadcasted shapes
mean = tensor(config.floatX, [True, False])
......@@ -369,16 +372,16 @@ def test_dirichlet_samples():
def test_dirichlet_infer_shape():
M_tt = iscalar("M")
M_tt.tag.test_value = 3
M_aet = iscalar("M")
M_aet.tag.test_value = 3
test_params = [
([aet.ones((M_tt,))], None),
([aet.ones((M_tt,))], (M_tt + 1,)),
([aet.ones((M_tt,))], (2, M_tt)),
([aet.ones((M_tt, M_tt + 1))], None),
([aet.ones((M_tt, M_tt + 1))], (M_tt + 2,)),
([aet.ones((M_tt, M_tt + 1))], (2, M_tt + 2, M_tt + 3)),
([aet.ones((M_aet,))], None),
([aet.ones((M_aet,))], (M_aet + 1,)),
([aet.ones((M_aet,))], (2, M_aet)),
([aet.ones((M_aet, M_aet + 1))], None),
([aet.ones((M_aet, M_aet + 1))], (M_aet + 2,)),
([aet.ones((M_aet, M_aet + 1))], (2, M_aet + 2, M_aet + 3)),
]
for args, size in test_params:
rv = dirichlet(*args, size=size)
......@@ -388,12 +391,12 @@ def test_dirichlet_infer_shape():
def test_dirichlet_ShapeFeature():
"""Make sure `RandomVariable.infer_shape` works with `ShapeFeature`."""
M_tt = iscalar("M")
M_tt.tag.test_value = 2
N_tt = iscalar("N")
N_tt.tag.test_value = 3
M_aet = iscalar("M")
M_aet.tag.test_value = 2
N_aet = iscalar("N")
N_aet.tag.test_value = 3
d_rv = dirichlet(aet.ones((M_tt, N_tt)), name="Gamma")
d_rv = dirichlet(aet.ones((M_aet, N_aet)), name="Gamma")
fg = FunctionGraph(
[i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
......@@ -404,8 +407,8 @@ def test_dirichlet_ShapeFeature():
s1, s2 = fg.shape_feature.shape_of[d_rv]
assert M_tt in graph_inputs([s1])
assert N_tt in graph_inputs([s2])
assert M_aet in graph_inputs([s1])
assert N_aet in graph_inputs([s2])
def test_poisson_samples():
......
......@@ -56,17 +56,17 @@ def test_inplace_optimization():
def check_shape_lifted_rv(rv, params, size, rng):
aet_params = []
for p in params:
p_tt = aet.as_tensor(p)
p_tt = p_tt.type()
p_tt.tag.test_value = p
aet_params.append(p_tt)
p_aet = aet.as_tensor(p)
p_aet = p_aet.type()
p_aet.tag.test_value = p
aet_params.append(p_aet)
aet_size = []
for s in size:
s_tt = aet.as_tensor(s)
s_tt = s_tt.type()
s_tt.tag.test_value = s
aet_size.append(s_tt)
s_aet = aet.as_tensor(s)
s_aet = s_aet.type()
s_aet.tag.test_value = s
aet_size.append(s_aet)
rv = rv(*aet_params, size=aet_size, rng=rng)
rv_lifted = lift_rv_shapes(rv.owner)
......@@ -243,22 +243,22 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
rng = shared(np.random.RandomState(1233532), borrow=False)
dist_params_tt = []
dist_params_aet = []
for p in dist_params:
p_tt = aet.as_tensor(p).type()
p_tt.tag.test_value = p
dist_params_tt.append(p_tt)
p_aet = aet.as_tensor(p).type()
p_aet.tag.test_value = p
dist_params_aet.append(p_aet)
size_tt = []
size_aet = []
for s in size:
s_tt = iscalar()
s_tt.tag.test_value = s
size_tt.append(s_tt)
s_aet = iscalar()
s_aet.tag.test_value = s
size_aet.append(s_aet)
dist_st = dist_op(*dist_params_tt, size=size_tt, rng=rng).dimshuffle(ds_order)
dist_st = dist_op(*dist_params_aet, size=size_aet, rng=rng).dimshuffle(ds_order)
f_inputs = [
p for p in dist_params_tt + size_tt if not isinstance(p, (slice, Constant))
p for p in dist_params_aet + size_aet if not isinstance(p, (slice, Constant))
]
mode = Mode(
......@@ -379,32 +379,32 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
rng = shared(np.random.RandomState(1233532), borrow=False)
dist_params_tt = []
dist_params_aet = []
for p in dist_params:
p_tt = aet.as_tensor(p).type()
p_tt.tag.test_value = p
dist_params_tt.append(p_tt)
p_aet = aet.as_tensor(p).type()
p_aet.tag.test_value = p
dist_params_aet.append(p_aet)
size_tt = []
size_aet = []
for s in size:
s_tt = iscalar()
s_tt.tag.test_value = s
size_tt.append(s_tt)
s_aet = iscalar()
s_aet.tag.test_value = s
size_aet.append(s_aet)
from aesara.tensor.subtensor import as_index_constant
indices_tt = ()
indices_aet = ()
for i in indices:
i_tt = as_index_constant(i)
if not isinstance(i_tt, slice):
i_tt.tag.test_value = i
indices_tt += (i_tt,)
i_aet = as_index_constant(i)
if not isinstance(i_aet, slice):
i_aet.tag.test_value = i
indices_aet += (i_aet,)
dist_st = dist_op(*dist_params_tt, size=size_tt, rng=rng)[indices_tt]
dist_st = dist_op(*dist_params_aet, size=size_aet, rng=rng)[indices_aet]
f_inputs = [
p
for p in dist_params_tt + size_tt + list(indices_tt)
for p in dist_params_aet + size_aet + list(indices_aet)
if not isinstance(p, (slice, Constant))
]
......
......@@ -728,11 +728,11 @@ class TestNonzero:
m_symb = tensor(dtype=m.dtype, broadcastable=(False,) * m.ndim)
m_symb.tag.test_value = m
res_tuple_tt = nonzero(m_symb, return_matrix=False)
res_matrix_tt = nonzero(m_symb, return_matrix=True)
res_tuple_aet = nonzero(m_symb, return_matrix=False)
res_matrix_aet = nonzero(m_symb, return_matrix=True)
res_tuple = tuple(r.tag.test_value for r in res_tuple_tt)
res_matrix = res_matrix_tt.tag.test_value
res_tuple = tuple(r.tag.test_value for r in res_tuple_aet)
res_matrix = res_matrix_aet.tag.test_value
assert np.allclose(res_matrix, np.vstack(np.nonzero(m)))
......@@ -757,9 +757,9 @@ class TestNonzero:
m_symb = tensor(dtype=m.dtype, broadcastable=(False,) * m.ndim)
m_symb.tag.test_value = m
res_tt = flatnonzero(m_symb)
res_aet = flatnonzero(m_symb)
result = res_tt.tag.test_value
result = res_aet.tag.test_value
assert np.allclose(result, np.flatnonzero(m))
rand0d = np.empty(())
......@@ -780,9 +780,9 @@ class TestNonzero:
m_symb = tensor(dtype=m.dtype, broadcastable=(False,) * m.ndim)
m_symb.tag.test_value = m
res_tt = nonzero_values(m_symb)
res_aet = nonzero_values(m_symb)
result = res_tt.tag.test_value
result = res_aet.tag.test_value
assert np.allclose(result, m[np.nonzero(m)])
rand0d = np.empty(())
......
......@@ -1223,106 +1223,114 @@ def test_broadcast_shape():
x = np.array([[1], [2], [3]])
y = np.array([4, 5, 6])
b = np.broadcast(x, y)
x_tt = aet.as_tensor_variable(x)
y_tt = aet.as_tensor_variable(y)
b_tt = broadcast_shape(x_tt, y_tt)
assert np.array_equal([z.eval() for z in b_tt], b.shape)
x_aet = aet.as_tensor_variable(x)
y_aet = aet.as_tensor_variable(y)
b_aet = broadcast_shape(x_aet, y_aet)
assert np.array_equal([z.eval() for z in b_aet], b.shape)
# Now, we try again using shapes as the inputs
#
# This case also confirms that a broadcast dimension will
# broadcast against a non-broadcast dimension when they're
# both symbolic (i.e. we couldn't obtain constant values).
b_tt = broadcast_shape(
shape_tuple(x_tt, use_bcast=False),
shape_tuple(y_tt, use_bcast=False),
b_aet = broadcast_shape(
shape_tuple(x_aet, use_bcast=False),
shape_tuple(y_aet, use_bcast=False),
arrays_are_shapes=True,
)
assert any(
isinstance(node.op, Assert) for node in applys_between([x_tt, y_tt], b_tt)
isinstance(node.op, Assert) for node in applys_between([x_aet, y_aet], b_aet)
)
assert np.array_equal([z.eval() for z in b_tt], b.shape)
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_tt], b.shape)
assert np.array_equal([z.eval() for z in b_aet], b.shape)
b_aet = broadcast_shape(
shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
)
assert np.array_equal([z.eval() for z in b_aet], b.shape)
# These are all constants, so there shouldn't be any asserts in the
# resulting graph.
assert not any(
isinstance(node.op, Assert) for node in applys_between([x_tt, y_tt], b_tt)
isinstance(node.op, Assert) for node in applys_between([x_aet, y_aet], b_aet)
)
x = np.array([1, 2, 3])
y = np.array([4, 5, 6])
b = np.broadcast(x, y)
x_tt = aet.as_tensor_variable(x)
y_tt = aet.as_tensor_variable(y)
b_tt = broadcast_shape(x_tt, y_tt)
assert np.array_equal([z.eval() for z in b_tt], b.shape)
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_tt], b.shape)
x_aet = aet.as_tensor_variable(x)
y_aet = aet.as_tensor_variable(y)
b_aet = broadcast_shape(x_aet, y_aet)
assert np.array_equal([z.eval() for z in b_aet], b.shape)
b_aet = broadcast_shape(
shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
)
assert np.array_equal([z.eval() for z in b_aet], b.shape)
# TODO: This will work when/if we use a more sophisticated `is_same_graph`
# implementation.
# assert not any(
# isinstance(node.op, Assert)
# for node in graph_ops([x_tt, y_tt], b_tt)
# for node in graph_ops([x_aet, y_aet], b_aet)
# )
x = np.empty((1, 2, 3))
y = np.array(1)
b = np.broadcast(x, y)
x_tt = aet.as_tensor_variable(x)
y_tt = aet.as_tensor_variable(y)
b_tt = broadcast_shape(x_tt, y_tt)
assert b_tt[0].value == 1
assert np.array_equal([z.eval() for z in b_tt], b.shape)
x_aet = aet.as_tensor_variable(x)
y_aet = aet.as_tensor_variable(y)
b_aet = broadcast_shape(x_aet, y_aet)
assert b_aet[0].value == 1
assert np.array_equal([z.eval() for z in b_aet], b.shape)
assert not any(
isinstance(node.op, Assert) for node in applys_between([x_tt, y_tt], b_tt)
isinstance(node.op, Assert) for node in applys_between([x_aet, y_aet], b_aet)
)
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_tt], b.shape)
b_aet = broadcast_shape(
shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
)
assert np.array_equal([z.eval() for z in b_aet], b.shape)
x = np.empty((2, 1, 3))
y = np.empty((2, 1, 1))
b = np.broadcast(x, y)
x_tt = aet.as_tensor_variable(x)
y_tt = aet.as_tensor_variable(y)
b_tt = broadcast_shape(x_tt, y_tt)
assert b_tt[1].value == 1
assert np.array_equal([z.eval() for z in b_tt], b.shape)
x_aet = aet.as_tensor_variable(x)
y_aet = aet.as_tensor_variable(y)
b_aet = broadcast_shape(x_aet, y_aet)
assert b_aet[1].value == 1
assert np.array_equal([z.eval() for z in b_aet], b.shape)
# TODO: This will work when/if we use a more sophisticated `is_same_graph`
# implementation.
# assert not any(
# isinstance(node.op, Assert)
# for node in graph_ops([x_tt, y_tt], b_tt)
# for node in graph_ops([x_aet, y_aet], b_aet)
# )
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_tt], b.shape)
x1_shp_tt = iscalar("x1")
x2_shp_tt = iscalar("x2")
y1_shp_tt = iscalar("y1")
x_shapes = (1, x1_shp_tt, x2_shp_tt)
x_tt = aet.ones(x_shapes)
y_shapes = (y1_shp_tt, 1, x2_shp_tt)
y_tt = aet.ones(y_shapes)
b_tt = broadcast_shape(x_tt, y_tt)
b_aet = broadcast_shape(
shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
)
assert np.array_equal([z.eval() for z in b_aet], b.shape)
x1_shp_aet = iscalar("x1")
x2_shp_aet = iscalar("x2")
y1_shp_aet = iscalar("y1")
x_shapes = (1, x1_shp_aet, x2_shp_aet)
x_aet = aet.ones(x_shapes)
y_shapes = (y1_shp_aet, 1, x2_shp_aet)
y_aet = aet.ones(y_shapes)
b_aet = broadcast_shape(x_aet, y_aet)
# TODO: This will work when/if we use a more sophisticated `is_same_graph`
# implementation.
# assert not any(
# isinstance(node.op, Assert)
# for node in graph_ops([x_tt, y_tt], b_tt)
# for node in graph_ops([x_aet, y_aet], b_aet)
# )
res = aet.as_tensor(b_tt).eval(
res = aet.as_tensor(b_aet).eval(
{
x1_shp_tt: 10,
x2_shp_tt: 4,
y1_shp_tt: 2,
x1_shp_aet: 10,
x2_shp_aet: 4,
y1_shp_aet: 2,
}
)
assert np.array_equal(res, (2, 10, 4))
y_shapes = (y1_shp_tt, 1, y1_shp_tt)
y_tt = aet.ones(y_shapes)
b_tt = broadcast_shape(x_tt, y_tt)
assert isinstance(b_tt[-1].owner.op, Assert)
y_shapes = (y1_shp_aet, 1, y1_shp_aet)
y_aet = aet.ones(y_shapes)
b_aet = broadcast_shape(x_aet, y_aet)
assert isinstance(b_aet[-1].owner.op, Assert)
class TestBroadcastTo(utt.InferShapeTester):
......@@ -1348,10 +1356,10 @@ class TestBroadcastTo(utt.InferShapeTester):
assert bcast_res.broadcastable == (False, True)
bcast_np = np.broadcast_to(5, (4, 1))
bcast_tt = bcast_res.get_test_value()
bcast_aet = bcast_res.get_test_value()
assert np.array_equal(bcast_tt, bcast_np)
assert np.shares_memory(bcast_tt, a.get_test_value())
assert np.array_equal(bcast_aet, bcast_np)
assert np.shares_memory(bcast_aet, a.get_test_value())
@pytest.mark.parametrize(
"fn,input_dims",
......
......@@ -932,9 +932,9 @@ class TestMaxAndArgmax:
def test_numpy_input(self):
ar = np.array([1, 2, 3])
max_tt, argmax_tt = max_and_argmax(ar, axis=None)
assert max_tt.eval(), 3
assert argmax_tt.eval(), 2
max_aet, argmax_aet = max_and_argmax(ar, axis=None)
assert max_aet.eval(), 3
assert argmax_aet.eval(), 2
class TestArgminArgmax:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论