提交 dc6d59c0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix strict type checking in aesara.tensor.basic_opt

上级 698b65d0
......@@ -76,7 +76,7 @@ from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.extra_ops import broadcast_shape
from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, shape, shape_padleft
from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padleft
from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import discrete_dtypes, integer_dtypes, lscalar
......@@ -509,7 +509,7 @@ compile.optdb.register(
def register_useless(lopt, *tags, **kwargs):
if type(lopt) == str:
if isinstance(lopt, str):
def register(inner_lopt):
return register_useless(inner_lopt, lopt, *tags, **kwargs)
......@@ -525,7 +525,7 @@ def register_useless(lopt, *tags, **kwargs):
def register_canonicalize(lopt, *tags, **kwargs):
if type(lopt) == str:
if isinstance(lopt, str):
def register(inner_lopt):
return register_canonicalize(inner_lopt, lopt, *tags, **kwargs)
......@@ -538,7 +538,7 @@ def register_canonicalize(lopt, *tags, **kwargs):
def register_stabilize(lopt, *tags, **kwargs):
if type(lopt) == str:
if isinstance(lopt, str):
def register(inner_lopt):
return register_stabilize(inner_lopt, lopt, *tags, **kwargs)
......@@ -551,7 +551,7 @@ def register_stabilize(lopt, *tags, **kwargs):
def register_specialize(lopt, *tags, **kwargs):
if type(lopt) == str:
if isinstance(lopt, str):
def register(inner_lopt):
return register_specialize(inner_lopt, lopt, *tags, **kwargs)
......@@ -564,7 +564,7 @@ def register_specialize(lopt, *tags, **kwargs):
def register_uncanonicalize(lopt, *tags, **kwargs):
if type(lopt) == str:
if isinstance(lopt, str):
def register(inner_lopt):
return register_uncanonicalize(inner_lopt, lopt, *tags, **kwargs)
......@@ -579,7 +579,7 @@ def register_uncanonicalize(lopt, *tags, **kwargs):
def register_specialize_device(lopt, *tags, **kwargs):
if type(lopt) == str:
if isinstance(lopt, str):
def register(inner_lopt):
return register_specialize_device(inner_lopt, lopt, *tags, **kwargs)
......@@ -1893,7 +1893,7 @@ compile.optdb.register(
@register_canonicalize
@local_optimizer([Shape])
def local_shape_to_shape_i(fgraph, node):
if node.op == shape:
if isinstance(node.op, Shape):
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(fgraph, "shape_feature"):
return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论