提交 b5d6f92a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move `test_local_subtensor_of_dot` to `test_subtensor_lift`

上级 4d539fa5
...@@ -1203,52 +1203,6 @@ def test_local_log_add_exp(): ...@@ -1203,52 +1203,6 @@ def test_local_log_add_exp():
# TODO: test that the rewrite works in the presence of broadcasting. # TODO: test that the rewrite works in the presence of broadcasting.
def test_local_subtensor_of_dot():
m1 = matrix()
m2 = matrix()
d1 = np.arange(6).reshape((3, 2)).astype(config.floatX)
d2 = np.arange(8).reshape((2, 4)).astype(config.floatX) + 10
mode = get_default_mode().including("local_subtensor_of_dot")
def test_equality(a, b):
return a.shape == b.shape and np.allclose(a, b)
# [cst]
f = function([m1, m2], pytensor.tensor.dot(m1, m2)[1], mode=mode)
topo = f.maker.fgraph.toposort()
assert test_equality(f(d1, d2), np.dot(d1, d2)[1])
# DimShuffle happen in FAST_COMPILE
assert isinstance(topo[-1].op, CGemv | Gemv | DimShuffle)
# slice
f = function([m1, m2], pytensor.tensor.dot(m1, m2)[1:2], mode=mode)
topo = f.maker.fgraph.toposort()
assert test_equality(f(d1, d2), np.dot(d1, d2)[1:2])
assert isinstance(topo[-1].op, Dot22)
m1 = tensor3()
m2 = tensor3()
idx = iscalar()
d1 = np.arange(30).reshape(2, 5, 3).astype(config.floatX)
d2 = np.arange(72).reshape(4, 3, 6).astype(config.floatX) + 100
f = function(
[m1, m2, idx], pytensor.tensor.dot(m1, m2)[idx, 1:4, :, idx:], mode=mode
)
assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1, 1:4, :, 1:])
# if we return the gradients. We need to use same mode as before.
assert check_stack_trace(f, ops_to_check="last")
f = function(
[m1, m2, idx], pytensor.tensor.dot(m1, m2)[1:4, :, idx:, idx], mode=mode
)
assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1:4, :, 1:, 1])
# Now test that the stack trace is copied over properly,
# if we return the gradients. We need to use same mode as before.
assert check_stack_trace(f, ops_to_check="last")
def test_local_elemwise_sub_zeros(): def test_local_elemwise_sub_zeros():
scal = scalar() scal = scalar()
vect = vector() vect = vector()
......
...@@ -37,6 +37,8 @@ from pytensor.tensor import ( ...@@ -37,6 +37,8 @@ from pytensor.tensor import (
vector, vector,
) )
from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.rewriting.subtensor_lift import ( from pytensor.tensor.rewriting.subtensor_lift import (
...@@ -178,6 +180,48 @@ class TestLocalSubtensorOfElemwise: ...@@ -178,6 +180,48 @@ class TestLocalSubtensorOfElemwise:
assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
def test_local_subtensor_of_dot():
m1 = matrix()
m2 = matrix()
d1 = np.arange(6).reshape((3, 2)).astype(config.floatX)
d2 = np.arange(8).reshape((2, 4)).astype(config.floatX) + 10
mode = get_default_mode().including("local_subtensor_of_dot")
def test_equality(a, b):
return a.shape == b.shape and np.allclose(a, b)
# [cst]
f = function([m1, m2], pt.dot(m1, m2)[1], mode=mode)
topo = f.maker.fgraph.toposort()
assert test_equality(f(d1, d2), np.dot(d1, d2)[1])
# DimShuffle happen in FAST_COMPILE
assert isinstance(topo[-1].op, CGemv | Gemv | DimShuffle)
# slice
f = function([m1, m2], pt.dot(m1, m2)[1:2], mode=mode)
topo = f.maker.fgraph.toposort()
assert test_equality(f(d1, d2), np.dot(d1, d2)[1:2])
assert isinstance(topo[-1].op, Dot22)
m1 = tensor3()
m2 = tensor3()
idx = iscalar()
d1 = np.arange(30).reshape(2, 5, 3).astype(config.floatX)
d2 = np.arange(72).reshape(4, 3, 6).astype(config.floatX) + 100
f = function([m1, m2, idx], pt.dot(m1, m2)[idx, 1:4, :, idx:], mode=mode)
assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1, 1:4, :, 1:])
# if we return the gradients. We need to use same mode as before.
assert check_stack_trace(f, ops_to_check="last")
f = function([m1, m2, idx], pt.dot(m1, m2)[1:4, :, idx:, idx], mode=mode)
assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1:4, :, 1:, 1])
# Now test that the stack trace is copied over properly,
# if we return the gradients. We need to use same mode as before.
assert check_stack_trace(f, ops_to_check="last")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"original_fn, expected_fn", "original_fn, expected_fn",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论