提交 4d74d13a authored 作者: Brendan Murphy's avatar Brendan Murphy 提交者: Ricardo Vieira

Updated doctests

From numpy PR https://github.com/numpy/numpy/pull/22449, the repr of scalar values has changed, e.g. from "1" to "np.int64(1)", which caused two doctests to fail.
上级 999a62c5
...@@ -256,7 +256,7 @@ def _general_dot( ...@@ -256,7 +256,7 @@ def _general_dot(
.. testoutput:: .. testoutput::
(3, 4, 2) (np.int64(3), np.int64(4), np.int64(2))
""" """
# Shortcut for non batched case # Shortcut for non batched case
if not batch_axes[0] and not batch_axes[1]: if not batch_axes[0] and not batch_axes[1]:
......
...@@ -757,13 +757,15 @@ def get_constant_idx( ...@@ -757,13 +757,15 @@ def get_constant_idx(
Example usage where `v` and `a` are appropriately typed PyTensor variables : Example usage where `v` and `a` are appropriately typed PyTensor variables :
>>> from pytensor.scalar import int64 >>> from pytensor.scalar import int64
>>> from pytensor.tensor import matrix >>> from pytensor.tensor import matrix
>>> import numpy as np
>>>
>>> v = int64("v") >>> v = int64("v")
>>> a = matrix("a") >>> a = matrix("a")
>>> b = a[v, 1:3] >>> b = a[v, 1:3]
>>> b.owner.op.idx_list >>> b.owner.op.idx_list
(ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None)) (ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None))
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True) >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
[v, slice(1, 3, None)] [v, slice(np.int64(1), np.int64(3), None)]
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs) >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
Traceback (most recent call last): Traceback (most recent call last):
pytensor.tensor.exceptions.NotScalarConstantError pytensor.tensor.exceptions.NotScalarConstantError
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论