提交 ec2351bf authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow single integer as TensorType shape

上级 58fc7106
...@@ -71,7 +71,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -71,7 +71,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def __init__( def __init__(
self, self,
dtype: str | npt.DTypeLike, dtype: str | npt.DTypeLike,
shape: Iterable[bool | int | None] | None = None, shape: Iterable[bool | int | None] | int | None = None,
name: str | None = None, name: str | None = None,
broadcastable: Iterable[bool] | None = None, broadcastable: Iterable[bool] | None = None,
): ):
...@@ -99,7 +99,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -99,7 +99,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
) )
shape = broadcastable shape = broadcastable
if str(dtype) == "floatX": if dtype == "floatX":
self.dtype = config.floatX self.dtype = config.floatX
else: else:
try: try:
...@@ -118,6 +118,8 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -118,6 +118,8 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
f"TensorType broadcastable/shape must be a boolean, integer or None, got {type(s)} {s}" f"TensorType broadcastable/shape must be a boolean, integer or None, got {type(s)} {s}"
) )
if isinstance(shape, int):
shape = (shape,)
self.shape = _shape = tuple(parse_bcast_and_shape(s) for s in shape) self.shape = _shape = tuple(parse_bcast_and_shape(s) for s in shape)
self.broadcastable = tuple(s == 1 for s in _shape) self.broadcastable = tuple(s == 1 for s in _shape)
self.ndim = _ndim = len(_shape) self.ndim = _ndim = len(_shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论