提交 62d2c048 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix docstring and type hints for indices_from_subtensor

上级 b41197d0
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
import sys import sys
from itertools import chain, groupby from itertools import chain, groupby
from textwrap import dedent from textwrap import dedent
from typing import Iterable, List, Tuple, Union from typing import Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -64,19 +64,19 @@ invalid_tensor_types = ( ...@@ -64,19 +64,19 @@ invalid_tensor_types = (
def indices_from_subtensor( def indices_from_subtensor(
op_indices: Iterable[Union[ScalarConstant]], op_indices: Iterable[ScalarConstant],
idx_list: List[Union[Type, slice, Variable]], idx_list: Optional[List[Union[Type, slice, Variable]]],
) -> Tuple[Union[slice, Variable]]: ) -> Tuple[Union[slice, Variable]]:
"""Recreate the index tuple from which a ``*Subtensor**`` ``Op`` was created. """Recreate the index tuple from which a ``*Subtensor**`` ``Op`` was created.
Parameters Parameters
========== ==========
op_indices op_indices
The flattened indices obtained from ``x.owner.inputs``, when ``x`` is a The flattened indices obtained from ``x.inputs``, when ``x`` is a
``*Subtensor*`` ``Op``. ``*Subtensor*`` node.
idx_list idx_list
The values describing the types of each dimension's index. This is The values describing the types of each dimension's index. This is
obtained from ``x.owner.inputs``, when ``x`` is a ``*Subtensor*`` obtained from ``op.idx_list``, when ``op`` is a ``*Subtensor*``
``Op``. ``Op``.
Example Example
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论