提交 a30f815f authored 作者: Sudarsan Mansingh's avatar Sudarsan Mansingh 提交者: Michael Osthege

Add type hint None to specify_shape

上级 d159f06d
......@@ -22,6 +22,9 @@ from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.var import TensorConstant, TensorVariable
ShapeValueType = Union[None, np.integer, int, Variable]
def register_shape_c_code(type, code, version=()):
"""
Tell Shape Op how to generate C code for an PyTensor Type.
......@@ -541,9 +544,7 @@ _specify_shape = SpecifyShape()
def specify_shape(
x: Union[np.ndarray, Number, Variable],
shape: Union[
int, List[Union[int, Variable]], Tuple[Union[int, Variable]], Variable
],
shape: Union[ShapeValueType, List[ShapeValueType], Tuple[ShapeValueType]],
):
"""Specify a fixed shape for a `Variable`.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论