提交 eb18f0ea authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid duplicated inputs in KroneckerProduct OpFromGraph

上级 9df35e8d
...@@ -400,6 +400,15 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -400,6 +400,15 @@ class OpFromGraph(Op, HasInnerGraph):
Check :func:`pytensor.function` for more arguments, only works when not Check :func:`pytensor.function` for more arguments, only works when not
inline. inline.
""" """
ignore_unused_inputs = kwargs.get("on_unused_input", False) == "ignore"
if not ignore_unused_inputs and len(inputs) != len(set(inputs)):
var_counts = {var: inputs.count(var) for var in inputs}
duplicated_inputs = [var for var, count in var_counts.items() if count > 1]
raise ValueError(
f"There following variables were provided more than once as inputs to the OpFromGraph, resulting in an "
f"invalid graph: {duplicated_inputs}. Use dummy variables or var.copy() to distinguish "
f"variables when creating the OpFromGraph graph."
)
if not (isinstance(inputs, list) and isinstance(outputs, list)): if not (isinstance(inputs, list) and isinstance(outputs, list)):
raise TypeError("Inputs and outputs must be lists") raise TypeError("Inputs and outputs must be lists")
......
...@@ -1034,6 +1034,11 @@ def kron(a, b): ...@@ -1034,6 +1034,11 @@ def kron(a, b):
""" """
a = as_tensor_variable(a) a = as_tensor_variable(a)
b = as_tensor_variable(b) b = as_tensor_variable(b)
if a is b:
# In case a is the same as b, we need a different variable to build the OFG
b = a.copy()
if a.ndim + b.ndim <= 2: if a.ndim + b.ndim <= 2:
raise TypeError( raise TypeError(
"kron: inputs dimensions must sum to 3 or more. " "kron: inputs dimensions must sum to 3 or more. "
......
...@@ -118,6 +118,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -118,6 +118,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
f = op(x, y, z) f = op(x, y, z)
f = f - grad(pt_sum(f), y) f = f - grad(pt_sum(f), y)
f = f - grad(pt_sum(f), y) f = f - grad(pt_sum(f), y)
fn = function([x, y, z], f) fn = function([x, y, z], f)
xv = np.ones((2, 2), dtype=config.floatX) xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3 yv = np.ones((2, 2), dtype=config.floatX) * 3
...@@ -584,6 +585,22 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -584,6 +585,22 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
out = test_ofg(y, y) out = test_ofg(y, y)
assert out.eval() == 4 assert out.eval() == 4
def test_repeated_inputs(self):
x = pt.dscalar("x")
y = pt.dscalar("y")
with pytest.raises(
ValueError,
match="There following variables were provided more than once as inputs to the "
"OpFromGraph",
):
OpFromGraph([x, x, y], [x + y])
# Test that repeated inputs will be allowed if unused inputs are ignored
g = OpFromGraph([x, x, y], [x + y], on_unused_input="ignore")
f = g(x, x, y)
assert f.eval({x: 5, y: 5}) == 10
@config.change_flags(floatX="float64") @config.change_flags(floatX="float64")
def test_debugprint(): def test_debugprint():
......
...@@ -514,8 +514,8 @@ def test_expm_grad_3(): ...@@ -514,8 +514,8 @@ def test_expm_grad_3():
def test_solve_discrete_lyapunov_via_direct_real(): def test_solve_discrete_lyapunov_via_direct_real():
N = 5 N = 5
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
a = pt.dmatrix() a = pt.dmatrix("a")
q = pt.dmatrix() q = pt.dmatrix("q")
f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")]) f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])
A = rng.normal(size=(N, N)) A = rng.normal(size=(N, N))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论