提交 471657a5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up some usage of the TensorType interface in Scan

上级 807c0c97
...@@ -161,8 +161,9 @@ def check_broadcast(v1, v2): ...@@ -161,8 +161,9 @@ def check_broadcast(v1, v2):
which may wrongly be interpreted as broadcastable. which may wrongly be interpreted as broadcastable.
""" """
if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"): if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType):
return return
msg = ( msg = (
"The broadcast pattern of the output of scan (%s) is " "The broadcast pattern of the output of scan (%s) is "
"inconsistent with the one provided in `output_info` " "inconsistent with the one provided in `output_info` "
...@@ -173,13 +174,13 @@ def check_broadcast(v1, v2): ...@@ -173,13 +174,13 @@ def check_broadcast(v1, v2):
"them consistent, e.g. using aesara.tensor." "them consistent, e.g. using aesara.tensor."
"{unbroadcast, specify_broadcastable}." "{unbroadcast, specify_broadcastable}."
) )
size = min(len(v1.broadcastable), len(v2.broadcastable)) size = min(v1.type.ndim, v2.type.ndim)
for n, (b1, b2) in enumerate( for n, (b1, b2) in enumerate(
zip(v1.broadcastable[-size:], v2.broadcastable[-size:]) zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:])
): ):
if b1 != b2: if b1 != b2:
a1 = n + size - len(v1.broadcastable) + 1 a1 = n + size - v1.type.ndim + 1
a2 = n + size - len(v2.broadcastable) + 1 a2 = n + size - v2.type.ndim + 1
raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2)) raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
...@@ -628,6 +629,7 @@ class ScanMethodsMixin: ...@@ -628,6 +629,7 @@ class ScanMethodsMixin:
type_input = self.inner_inputs[inner_iidx].type type_input = self.inner_inputs[inner_iidx].type
type_output = self.inner_outputs[inner_oidx].type type_output = self.inner_outputs[inner_oidx].type
if ( if (
# TODO: Use the `Type` interface for this
type_input.dtype != type_output.dtype type_input.dtype != type_output.dtype
or type_input.broadcastable != type_output.broadcastable or type_input.broadcastable != type_output.broadcastable
): ):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论