提交 ec2351bf authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow single integer as TensorType shape

上级 58fc7106
......@@ -71,7 +71,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def __init__(
self,
dtype: str | npt.DTypeLike,
shape: Iterable[bool | int | None] | None = None,
shape: Iterable[bool | int | None] | int | None = None,
name: str | None = None,
broadcastable: Iterable[bool] | None = None,
):
......@@ -99,7 +99,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
)
shape = broadcastable
if str(dtype) == "floatX":
if dtype == "floatX":
self.dtype = config.floatX
else:
try:
......@@ -118,6 +118,8 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
f"TensorType broadcastable/shape must be a boolean, integer or None, got {type(s)} {s}"
)
if isinstance(shape, int):
shape = (shape,)
self.shape = _shape = tuple(parse_bcast_and_shape(s) for s in shape)
self.broadcastable = tuple(s == 1 for s in _shape)
self.ndim = _ndim = len(_shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论