Unverified 提交 5c8afaef authored 作者: Tanish's avatar Tanish 提交者: GitHub

added check for square matrix in make_node to Det (#861)

上级 ffca031c
......@@ -198,7 +198,15 @@ class Det(Op):
def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2
if x.ndim != 2:
raise ValueError(
f"Input passed is not a valid 2D matrix. Current ndim {x.ndim} != 2"
)
# Check for known shapes and square matrix
if None not in x.type.shape and (x.type.shape[0] != x.type.shape[1]):
raise ValueError(
f"Determinant not defined for non-square matrix inputs. Shape received is {x.type.shape}"
)
o = scalar(dtype=x.dtype)
return Apply(self, [x], [o])
......
......@@ -365,6 +365,11 @@ def test_det():
assert np.allclose(np.linalg.det(r), f(r))
def test_det_non_square_raises():
with pytest.raises(ValueError, match="Determinant not defined"):
det(tensor("x", shape=(5, 7)))
def test_det_grad():
rng = np.random.default_rng(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论