提交 4d74d13a authored 作者: Brendan Murphy's avatar Brendan Murphy 提交者: Ricardo Vieira

Updated doctests

From numpy PR https://github.com/numpy/numpy/pull/22449, the repr of scalar values has changed, e.g. from "1" to "np.int64(1)", which caused two doctests to fail.
上级 999a62c5
......@@ -256,7 +256,7 @@ def _general_dot(
.. testoutput::
(3, 4, 2)
(np.int64(3), np.int64(4), np.int64(2))
"""
# Shortcut for non batched case
if not batch_axes[0] and not batch_axes[1]:
......
......@@ -757,13 +757,15 @@ def get_constant_idx(
Example usage where `v` and `a` are appropriately typed PyTensor variables :
>>> from pytensor.scalar import int64
>>> from pytensor.tensor import matrix
>>> import numpy as np
>>>
>>> v = int64("v")
>>> a = matrix("a")
>>> b = a[v, 1:3]
>>> b.owner.op.idx_list
(ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None))
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
[v, slice(1, 3, None)]
[v, slice(np.int64(1), np.int64(3), None)]
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
Traceback (most recent call last):
pytensor.tensor.exceptions.NotScalarConstantError
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论