提交 4a8ccb6d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename _tt suffix to _aet

上级 9ac485c5
...@@ -24,9 +24,9 @@ class Assert(COp): ...@@ -24,9 +24,9 @@ class Assert(COp):
Examples Examples
-------- --------
>>> import aesara >>> import aesara
>>> T = aesara.tensor >>> import aesara.tensor as aet
>>> x = T.vector('x') >>> x = aet.vector('x')
>>> assert_op = T.opt.Assert() >>> assert_op = aet.opt.Assert()
>>> func = aesara.function([x], assert_op(x, x.size<2)) >>> func = aesara.function([x], assert_op(x, x.size<2))
""" """
......
...@@ -419,7 +419,7 @@ def jax_funcify_Composite(op): ...@@ -419,7 +419,7 @@ def jax_funcify_Composite(op):
@jax_funcify.register(Scan) @jax_funcify.register(Scan)
def jax_funcify_Scan(op): def jax_funcify_Scan(op):
inner_fg = FunctionGraph(op.inputs, op.outputs) inner_fg = FunctionGraph(op.inputs, op.outputs)
jax_tt_inner_func = jax_funcify(inner_fg) jax_aet_inner_func = jax_funcify(inner_fg)
def scan(*outer_inputs): def scan(*outer_inputs):
scan_args = ScanArgs( scan_args = ScanArgs(
...@@ -538,7 +538,7 @@ def jax_funcify_Scan(op): ...@@ -538,7 +538,7 @@ def jax_funcify_Scan(op):
def jax_inner_func(carry, x): def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_scan(op, carry, x) inner_args = jax_args_to_inner_scan(op, carry, x)
inner_scan_outs = [fn(*inner_args) for fn in jax_tt_inner_func] inner_scan_outs = [fn(*inner_args) for fn in jax_aet_inner_func]
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs) new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
return new_carry, inner_scan_outs return new_carry, inner_scan_outs
......
差异被折叠。
...@@ -104,23 +104,26 @@ def test_beta_samples(): ...@@ -104,23 +104,26 @@ def test_beta_samples():
def test_normal_infer_shape(): def test_normal_infer_shape():
M_tt = iscalar("M") M_aet = iscalar("M")
M_tt.tag.test_value = 3 M_aet.tag.test_value = 3
sd_tt = scalar("sd") sd_aet = scalar("sd")
sd_tt.tag.test_value = np.array(1.0, dtype=config.floatX) sd_aet.tag.test_value = np.array(1.0, dtype=config.floatX)
test_params = [ test_params = [
([aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_tt], None), ([aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_aet], None),
([aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_tt], (M_tt,)),
( (
[aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_tt], [aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_aet],
(2, M_tt), (M_aet,),
), ),
([aet.zeros((M_tt,)), sd_tt], None), (
([aet.zeros((M_tt,)), sd_tt], (M_tt,)), [aet.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_aet],
([aet.zeros((M_tt,)), sd_tt], (2, M_tt)), (2, M_aet),
([aet.zeros((M_tt,)), aet.ones((M_tt,))], None), ),
([aet.zeros((M_tt,)), aet.ones((M_tt,))], (2, M_tt)), ([aet.zeros((M_aet,)), sd_aet], None),
([aet.zeros((M_aet,)), sd_aet], (M_aet,)),
([aet.zeros((M_aet,)), sd_aet], (2, M_aet)),
([aet.zeros((M_aet,)), aet.ones((M_aet,))], None),
([aet.zeros((M_aet,)), aet.ones((M_aet,))], (2, M_aet)),
( (
[ [
np.array([[-1, 20], [300, -4000]], dtype=config.floatX), np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
...@@ -140,12 +143,12 @@ def test_normal_infer_shape(): ...@@ -140,12 +143,12 @@ def test_normal_infer_shape():
def test_normal_ShapeFeature(): def test_normal_ShapeFeature():
M_tt = iscalar("M") M_aet = iscalar("M")
M_tt.tag.test_value = 3 M_aet.tag.test_value = 3
sd_tt = scalar("sd") sd_aet = scalar("sd")
sd_tt.tag.test_value = np.array(1.0, dtype=config.floatX) sd_aet.tag.test_value = np.array(1.0, dtype=config.floatX)
d_rv = normal(aet.ones((M_tt,)), sd_tt, size=(2, M_tt)) d_rv = normal(aet.ones((M_aet,)), sd_aet, size=(2, M_aet))
d_rv.tag.test_value d_rv.tag.test_value
fg = FunctionGraph( fg = FunctionGraph(
...@@ -294,10 +297,10 @@ def test_mvnormal_samples(): ...@@ -294,10 +297,10 @@ def test_mvnormal_samples():
def test_mvnormal_ShapeFeature(): def test_mvnormal_ShapeFeature():
M_tt = iscalar("M") M_aet = iscalar("M")
M_tt.tag.test_value = 2 M_aet.tag.test_value = 2
d_rv = multivariate_normal(aet.ones((M_tt,)), aet.eye(M_tt), size=2) d_rv = multivariate_normal(aet.ones((M_aet,)), aet.eye(M_aet), size=2)
fg = FunctionGraph( fg = FunctionGraph(
[i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)], [i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
...@@ -309,7 +312,7 @@ def test_mvnormal_ShapeFeature(): ...@@ -309,7 +312,7 @@ def test_mvnormal_ShapeFeature():
s1, s2 = fg.shape_feature.shape_of[d_rv] s1, s2 = fg.shape_feature.shape_of[d_rv]
assert get_test_value(s1) == 2 assert get_test_value(s1) == 2
assert M_tt in graph_inputs([s2]) assert M_aet in graph_inputs([s2])
# Test broadcasted shapes # Test broadcasted shapes
mean = tensor(config.floatX, [True, False]) mean = tensor(config.floatX, [True, False])
...@@ -369,16 +372,16 @@ def test_dirichlet_samples(): ...@@ -369,16 +372,16 @@ def test_dirichlet_samples():
def test_dirichlet_infer_shape(): def test_dirichlet_infer_shape():
M_tt = iscalar("M") M_aet = iscalar("M")
M_tt.tag.test_value = 3 M_aet.tag.test_value = 3
test_params = [ test_params = [
([aet.ones((M_tt,))], None), ([aet.ones((M_aet,))], None),
([aet.ones((M_tt,))], (M_tt + 1,)), ([aet.ones((M_aet,))], (M_aet + 1,)),
([aet.ones((M_tt,))], (2, M_tt)), ([aet.ones((M_aet,))], (2, M_aet)),
([aet.ones((M_tt, M_tt + 1))], None), ([aet.ones((M_aet, M_aet + 1))], None),
([aet.ones((M_tt, M_tt + 1))], (M_tt + 2,)), ([aet.ones((M_aet, M_aet + 1))], (M_aet + 2,)),
([aet.ones((M_tt, M_tt + 1))], (2, M_tt + 2, M_tt + 3)), ([aet.ones((M_aet, M_aet + 1))], (2, M_aet + 2, M_aet + 3)),
] ]
for args, size in test_params: for args, size in test_params:
rv = dirichlet(*args, size=size) rv = dirichlet(*args, size=size)
...@@ -388,12 +391,12 @@ def test_dirichlet_infer_shape(): ...@@ -388,12 +391,12 @@ def test_dirichlet_infer_shape():
def test_dirichlet_ShapeFeature(): def test_dirichlet_ShapeFeature():
"""Make sure `RandomVariable.infer_shape` works with `ShapeFeature`.""" """Make sure `RandomVariable.infer_shape` works with `ShapeFeature`."""
M_tt = iscalar("M") M_aet = iscalar("M")
M_tt.tag.test_value = 2 M_aet.tag.test_value = 2
N_tt = iscalar("N") N_aet = iscalar("N")
N_tt.tag.test_value = 3 N_aet.tag.test_value = 3
d_rv = dirichlet(aet.ones((M_tt, N_tt)), name="Gamma") d_rv = dirichlet(aet.ones((M_aet, N_aet)), name="Gamma")
fg = FunctionGraph( fg = FunctionGraph(
[i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)], [i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
...@@ -404,8 +407,8 @@ def test_dirichlet_ShapeFeature(): ...@@ -404,8 +407,8 @@ def test_dirichlet_ShapeFeature():
s1, s2 = fg.shape_feature.shape_of[d_rv] s1, s2 = fg.shape_feature.shape_of[d_rv]
assert M_tt in graph_inputs([s1]) assert M_aet in graph_inputs([s1])
assert N_tt in graph_inputs([s2]) assert N_aet in graph_inputs([s2])
def test_poisson_samples(): def test_poisson_samples():
......
...@@ -56,17 +56,17 @@ def test_inplace_optimization(): ...@@ -56,17 +56,17 @@ def test_inplace_optimization():
def check_shape_lifted_rv(rv, params, size, rng): def check_shape_lifted_rv(rv, params, size, rng):
aet_params = [] aet_params = []
for p in params: for p in params:
p_tt = aet.as_tensor(p) p_aet = aet.as_tensor(p)
p_tt = p_tt.type() p_aet = p_aet.type()
p_tt.tag.test_value = p p_aet.tag.test_value = p
aet_params.append(p_tt) aet_params.append(p_aet)
aet_size = [] aet_size = []
for s in size: for s in size:
s_tt = aet.as_tensor(s) s_aet = aet.as_tensor(s)
s_tt = s_tt.type() s_aet = s_aet.type()
s_tt.tag.test_value = s s_aet.tag.test_value = s
aet_size.append(s_tt) aet_size.append(s_aet)
rv = rv(*aet_params, size=aet_size, rng=rng) rv = rv(*aet_params, size=aet_size, rng=rng)
rv_lifted = lift_rv_shapes(rv.owner) rv_lifted = lift_rv_shapes(rv.owner)
...@@ -243,22 +243,22 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -243,22 +243,22 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
rng = shared(np.random.RandomState(1233532), borrow=False) rng = shared(np.random.RandomState(1233532), borrow=False)
dist_params_tt = [] dist_params_aet = []
for p in dist_params: for p in dist_params:
p_tt = aet.as_tensor(p).type() p_aet = aet.as_tensor(p).type()
p_tt.tag.test_value = p p_aet.tag.test_value = p
dist_params_tt.append(p_tt) dist_params_aet.append(p_aet)
size_tt = [] size_aet = []
for s in size: for s in size:
s_tt = iscalar() s_aet = iscalar()
s_tt.tag.test_value = s s_aet.tag.test_value = s
size_tt.append(s_tt) size_aet.append(s_aet)
dist_st = dist_op(*dist_params_tt, size=size_tt, rng=rng).dimshuffle(ds_order) dist_st = dist_op(*dist_params_aet, size=size_aet, rng=rng).dimshuffle(ds_order)
f_inputs = [ f_inputs = [
p for p in dist_params_tt + size_tt if not isinstance(p, (slice, Constant)) p for p in dist_params_aet + size_aet if not isinstance(p, (slice, Constant))
] ]
mode = Mode( mode = Mode(
...@@ -379,32 +379,32 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size): ...@@ -379,32 +379,32 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
rng = shared(np.random.RandomState(1233532), borrow=False) rng = shared(np.random.RandomState(1233532), borrow=False)
dist_params_tt = [] dist_params_aet = []
for p in dist_params: for p in dist_params:
p_tt = aet.as_tensor(p).type() p_aet = aet.as_tensor(p).type()
p_tt.tag.test_value = p p_aet.tag.test_value = p
dist_params_tt.append(p_tt) dist_params_aet.append(p_aet)
size_tt = [] size_aet = []
for s in size: for s in size:
s_tt = iscalar() s_aet = iscalar()
s_tt.tag.test_value = s s_aet.tag.test_value = s
size_tt.append(s_tt) size_aet.append(s_aet)
from aesara.tensor.subtensor import as_index_constant from aesara.tensor.subtensor import as_index_constant
indices_tt = () indices_aet = ()
for i in indices: for i in indices:
i_tt = as_index_constant(i) i_aet = as_index_constant(i)
if not isinstance(i_tt, slice): if not isinstance(i_aet, slice):
i_tt.tag.test_value = i i_aet.tag.test_value = i
indices_tt += (i_tt,) indices_aet += (i_aet,)
dist_st = dist_op(*dist_params_tt, size=size_tt, rng=rng)[indices_tt] dist_st = dist_op(*dist_params_aet, size=size_aet, rng=rng)[indices_aet]
f_inputs = [ f_inputs = [
p p
for p in dist_params_tt + size_tt + list(indices_tt) for p in dist_params_aet + size_aet + list(indices_aet)
if not isinstance(p, (slice, Constant)) if not isinstance(p, (slice, Constant))
] ]
......
...@@ -728,11 +728,11 @@ class TestNonzero: ...@@ -728,11 +728,11 @@ class TestNonzero:
m_symb = tensor(dtype=m.dtype, broadcastable=(False,) * m.ndim) m_symb = tensor(dtype=m.dtype, broadcastable=(False,) * m.ndim)
m_symb.tag.test_value = m m_symb.tag.test_value = m
res_tuple_tt = nonzero(m_symb, return_matrix=False) res_tuple_aet = nonzero(m_symb, return_matrix=False)
res_matrix_tt = nonzero(m_symb, return_matrix=True) res_matrix_aet = nonzero(m_symb, return_matrix=True)
res_tuple = tuple(r.tag.test_value for r in res_tuple_tt) res_tuple = tuple(r.tag.test_value for r in res_tuple_aet)
res_matrix = res_matrix_tt.tag.test_value res_matrix = res_matrix_aet.tag.test_value
assert np.allclose(res_matrix, np.vstack(np.nonzero(m))) assert np.allclose(res_matrix, np.vstack(np.nonzero(m)))
...@@ -757,9 +757,9 @@ class TestNonzero: ...@@ -757,9 +757,9 @@ class TestNonzero:
m_symb = tensor(dtype=m.dtype, broadcastable=(False,) * m.ndim) m_symb = tensor(dtype=m.dtype, broadcastable=(False,) * m.ndim)
m_symb.tag.test_value = m m_symb.tag.test_value = m
res_tt = flatnonzero(m_symb) res_aet = flatnonzero(m_symb)
result = res_tt.tag.test_value result = res_aet.tag.test_value
assert np.allclose(result, np.flatnonzero(m)) assert np.allclose(result, np.flatnonzero(m))
rand0d = np.empty(()) rand0d = np.empty(())
...@@ -780,9 +780,9 @@ class TestNonzero: ...@@ -780,9 +780,9 @@ class TestNonzero:
m_symb = tensor(dtype=m.dtype, broadcastable=(False,) * m.ndim) m_symb = tensor(dtype=m.dtype, broadcastable=(False,) * m.ndim)
m_symb.tag.test_value = m m_symb.tag.test_value = m
res_tt = nonzero_values(m_symb) res_aet = nonzero_values(m_symb)
result = res_tt.tag.test_value result = res_aet.tag.test_value
assert np.allclose(result, m[np.nonzero(m)]) assert np.allclose(result, m[np.nonzero(m)])
rand0d = np.empty(()) rand0d = np.empty(())
......
...@@ -1223,106 +1223,114 @@ def test_broadcast_shape(): ...@@ -1223,106 +1223,114 @@ def test_broadcast_shape():
x = np.array([[1], [2], [3]]) x = np.array([[1], [2], [3]])
y = np.array([4, 5, 6]) y = np.array([4, 5, 6])
b = np.broadcast(x, y) b = np.broadcast(x, y)
x_tt = aet.as_tensor_variable(x) x_aet = aet.as_tensor_variable(x)
y_tt = aet.as_tensor_variable(y) y_aet = aet.as_tensor_variable(y)
b_tt = broadcast_shape(x_tt, y_tt) b_aet = broadcast_shape(x_aet, y_aet)
assert np.array_equal([z.eval() for z in b_tt], b.shape) assert np.array_equal([z.eval() for z in b_aet], b.shape)
# Now, we try again using shapes as the inputs # Now, we try again using shapes as the inputs
# #
# This case also confirms that a broadcast dimension will # This case also confirms that a broadcast dimension will
# broadcast against a non-broadcast dimension when they're # broadcast against a non-broadcast dimension when they're
# both symbolic (i.e. we couldn't obtain constant values). # both symbolic (i.e. we couldn't obtain constant values).
b_tt = broadcast_shape( b_aet = broadcast_shape(
shape_tuple(x_tt, use_bcast=False), shape_tuple(x_aet, use_bcast=False),
shape_tuple(y_tt, use_bcast=False), shape_tuple(y_aet, use_bcast=False),
arrays_are_shapes=True, arrays_are_shapes=True,
) )
assert any( assert any(
isinstance(node.op, Assert) for node in applys_between([x_tt, y_tt], b_tt) isinstance(node.op, Assert) for node in applys_between([x_aet, y_aet], b_aet)
) )
assert np.array_equal([z.eval() for z in b_tt], b.shape) assert np.array_equal([z.eval() for z in b_aet], b.shape)
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True) b_aet = broadcast_shape(
assert np.array_equal([z.eval() for z in b_tt], b.shape) shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
)
assert np.array_equal([z.eval() for z in b_aet], b.shape)
# These are all constants, so there shouldn't be any asserts in the # These are all constants, so there shouldn't be any asserts in the
# resulting graph. # resulting graph.
assert not any( assert not any(
isinstance(node.op, Assert) for node in applys_between([x_tt, y_tt], b_tt) isinstance(node.op, Assert) for node in applys_between([x_aet, y_aet], b_aet)
) )
x = np.array([1, 2, 3]) x = np.array([1, 2, 3])
y = np.array([4, 5, 6]) y = np.array([4, 5, 6])
b = np.broadcast(x, y) b = np.broadcast(x, y)
x_tt = aet.as_tensor_variable(x) x_aet = aet.as_tensor_variable(x)
y_tt = aet.as_tensor_variable(y) y_aet = aet.as_tensor_variable(y)
b_tt = broadcast_shape(x_tt, y_tt) b_aet = broadcast_shape(x_aet, y_aet)
assert np.array_equal([z.eval() for z in b_tt], b.shape) assert np.array_equal([z.eval() for z in b_aet], b.shape)
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True) b_aet = broadcast_shape(
assert np.array_equal([z.eval() for z in b_tt], b.shape) shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
)
assert np.array_equal([z.eval() for z in b_aet], b.shape)
# TODO: This will work when/if we use a more sophisticated `is_same_graph` # TODO: This will work when/if we use a more sophisticated `is_same_graph`
# implementation. # implementation.
# assert not any( # assert not any(
# isinstance(node.op, Assert) # isinstance(node.op, Assert)
# for node in graph_ops([x_tt, y_tt], b_tt) # for node in graph_ops([x_aet, y_aet], b_aet)
# ) # )
x = np.empty((1, 2, 3)) x = np.empty((1, 2, 3))
y = np.array(1) y = np.array(1)
b = np.broadcast(x, y) b = np.broadcast(x, y)
x_tt = aet.as_tensor_variable(x) x_aet = aet.as_tensor_variable(x)
y_tt = aet.as_tensor_variable(y) y_aet = aet.as_tensor_variable(y)
b_tt = broadcast_shape(x_tt, y_tt) b_aet = broadcast_shape(x_aet, y_aet)
assert b_tt[0].value == 1 assert b_aet[0].value == 1
assert np.array_equal([z.eval() for z in b_tt], b.shape) assert np.array_equal([z.eval() for z in b_aet], b.shape)
assert not any( assert not any(
isinstance(node.op, Assert) for node in applys_between([x_tt, y_tt], b_tt) isinstance(node.op, Assert) for node in applys_between([x_aet, y_aet], b_aet)
) )
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True) b_aet = broadcast_shape(
assert np.array_equal([z.eval() for z in b_tt], b.shape) shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
)
assert np.array_equal([z.eval() for z in b_aet], b.shape)
x = np.empty((2, 1, 3)) x = np.empty((2, 1, 3))
y = np.empty((2, 1, 1)) y = np.empty((2, 1, 1))
b = np.broadcast(x, y) b = np.broadcast(x, y)
x_tt = aet.as_tensor_variable(x) x_aet = aet.as_tensor_variable(x)
y_tt = aet.as_tensor_variable(y) y_aet = aet.as_tensor_variable(y)
b_tt = broadcast_shape(x_tt, y_tt) b_aet = broadcast_shape(x_aet, y_aet)
assert b_tt[1].value == 1 assert b_aet[1].value == 1
assert np.array_equal([z.eval() for z in b_tt], b.shape) assert np.array_equal([z.eval() for z in b_aet], b.shape)
# TODO: This will work when/if we use a more sophisticated `is_same_graph` # TODO: This will work when/if we use a more sophisticated `is_same_graph`
# implementation. # implementation.
# assert not any( # assert not any(
# isinstance(node.op, Assert) # isinstance(node.op, Assert)
# for node in graph_ops([x_tt, y_tt], b_tt) # for node in graph_ops([x_aet, y_aet], b_aet)
# ) # )
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True) b_aet = broadcast_shape(
assert np.array_equal([z.eval() for z in b_tt], b.shape) shape_tuple(x_aet), shape_tuple(y_aet), arrays_are_shapes=True
)
x1_shp_tt = iscalar("x1") assert np.array_equal([z.eval() for z in b_aet], b.shape)
x2_shp_tt = iscalar("x2")
y1_shp_tt = iscalar("y1") x1_shp_aet = iscalar("x1")
x_shapes = (1, x1_shp_tt, x2_shp_tt) x2_shp_aet = iscalar("x2")
x_tt = aet.ones(x_shapes) y1_shp_aet = iscalar("y1")
y_shapes = (y1_shp_tt, 1, x2_shp_tt) x_shapes = (1, x1_shp_aet, x2_shp_aet)
y_tt = aet.ones(y_shapes) x_aet = aet.ones(x_shapes)
b_tt = broadcast_shape(x_tt, y_tt) y_shapes = (y1_shp_aet, 1, x2_shp_aet)
y_aet = aet.ones(y_shapes)
b_aet = broadcast_shape(x_aet, y_aet)
# TODO: This will work when/if we use a more sophisticated `is_same_graph` # TODO: This will work when/if we use a more sophisticated `is_same_graph`
# implementation. # implementation.
# assert not any( # assert not any(
# isinstance(node.op, Assert) # isinstance(node.op, Assert)
# for node in graph_ops([x_tt, y_tt], b_tt) # for node in graph_ops([x_aet, y_aet], b_aet)
# ) # )
res = aet.as_tensor(b_tt).eval( res = aet.as_tensor(b_aet).eval(
{ {
x1_shp_tt: 10, x1_shp_aet: 10,
x2_shp_tt: 4, x2_shp_aet: 4,
y1_shp_tt: 2, y1_shp_aet: 2,
} }
) )
assert np.array_equal(res, (2, 10, 4)) assert np.array_equal(res, (2, 10, 4))
y_shapes = (y1_shp_tt, 1, y1_shp_tt) y_shapes = (y1_shp_aet, 1, y1_shp_aet)
y_tt = aet.ones(y_shapes) y_aet = aet.ones(y_shapes)
b_tt = broadcast_shape(x_tt, y_tt) b_aet = broadcast_shape(x_aet, y_aet)
assert isinstance(b_tt[-1].owner.op, Assert) assert isinstance(b_aet[-1].owner.op, Assert)
class TestBroadcastTo(utt.InferShapeTester): class TestBroadcastTo(utt.InferShapeTester):
...@@ -1348,10 +1356,10 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1348,10 +1356,10 @@ class TestBroadcastTo(utt.InferShapeTester):
assert bcast_res.broadcastable == (False, True) assert bcast_res.broadcastable == (False, True)
bcast_np = np.broadcast_to(5, (4, 1)) bcast_np = np.broadcast_to(5, (4, 1))
bcast_tt = bcast_res.get_test_value() bcast_aet = bcast_res.get_test_value()
assert np.array_equal(bcast_tt, bcast_np) assert np.array_equal(bcast_aet, bcast_np)
assert np.shares_memory(bcast_tt, a.get_test_value()) assert np.shares_memory(bcast_aet, a.get_test_value())
@pytest.mark.parametrize( @pytest.mark.parametrize(
"fn,input_dims", "fn,input_dims",
......
...@@ -932,9 +932,9 @@ class TestMaxAndArgmax: ...@@ -932,9 +932,9 @@ class TestMaxAndArgmax:
def test_numpy_input(self): def test_numpy_input(self):
ar = np.array([1, 2, 3]) ar = np.array([1, 2, 3])
max_tt, argmax_tt = max_and_argmax(ar, axis=None) max_aet, argmax_aet = max_and_argmax(ar, axis=None)
assert max_tt.eval(), 3 assert max_aet.eval(), 3
assert argmax_tt.eval(), 2 assert argmax_aet.eval(), 2
class TestArgminArgmax: class TestArgminArgmax:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论