提交 471657a5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up some usage of the TensorType interface in Scan

上级 807c0c97
......@@ -161,8 +161,9 @@ def check_broadcast(v1, v2):
which may wrongly be interpreted as broadcastable.
"""
if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType):
return
msg = (
"The broadcast pattern of the output of scan (%s) is "
"inconsistent with the one provided in `output_info` "
......@@ -173,13 +174,13 @@ def check_broadcast(v1, v2):
"them consistent, e.g. using aesara.tensor."
"{unbroadcast, specify_broadcastable}."
)
size = min(len(v1.broadcastable), len(v2.broadcastable))
size = min(v1.type.ndim, v2.type.ndim)
for n, (b1, b2) in enumerate(
zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:])
):
if b1 != b2:
a1 = n + size - len(v1.broadcastable) + 1
a2 = n + size - len(v2.broadcastable) + 1
a1 = n + size - v1.type.ndim + 1
a2 = n + size - v2.type.ndim + 1
raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
......@@ -628,6 +629,7 @@ class ScanMethodsMixin:
type_input = self.inner_inputs[inner_iidx].type
type_output = self.inner_outputs[inner_oidx].type
if (
# TODO: Use the `Type` interface for this
type_input.dtype != type_output.dtype
or type_input.broadcastable != type_output.broadcastable
):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论