提交 a30f815f authored 作者: Sudarsan Mansingh's avatar Sudarsan Mansingh 提交者: Michael Osthege

Add type hint None to specify_shape

上级 d159f06d
...@@ -22,6 +22,9 @@ from pytensor.tensor.type_other import NoneConst ...@@ -22,6 +22,9 @@ from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.var import TensorConstant, TensorVariable from pytensor.tensor.var import TensorConstant, TensorVariable
ShapeValueType = Union[None, np.integer, int, Variable]
def register_shape_c_code(type, code, version=()): def register_shape_c_code(type, code, version=()):
""" """
Tell Shape Op how to generate C code for an PyTensor Type. Tell Shape Op how to generate C code for an PyTensor Type.
...@@ -541,9 +544,7 @@ _specify_shape = SpecifyShape() ...@@ -541,9 +544,7 @@ _specify_shape = SpecifyShape()
def specify_shape( def specify_shape(
x: Union[np.ndarray, Number, Variable], x: Union[np.ndarray, Number, Variable],
shape: Union[ shape: Union[ShapeValueType, List[ShapeValueType], Tuple[ShapeValueType]],
int, List[Union[int, Variable]], Tuple[Union[int, Variable]], Variable
],
): ):
"""Specify a fixed shape for a `Variable`. """Specify a fixed shape for a `Variable`.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论