提交 698b65d0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use Type instead of CType in get_scalar_constant_value

上级 930603b0
...@@ -26,7 +26,7 @@ from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefin ...@@ -26,7 +26,7 @@ from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefin
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp, Op from aesara.graph.op import COp, Op
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import CType from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import min_informative_str, pprint from aesara.printing import min_informative_str, pprint
from aesara.scalar import int32 from aesara.scalar import int32
...@@ -468,7 +468,7 @@ def get_scalar_constant_value( ...@@ -468,7 +468,7 @@ def get_scalar_constant_value(
var.ndim == 0 for var in v.owner.inputs[0].owner.inputs[1:] var.ndim == 0 for var in v.owner.inputs[0].owner.inputs[1:]
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, CType): if isinstance(idx, Type):
idx = get_scalar_constant_value( idx = get_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
...@@ -482,7 +482,7 @@ def get_scalar_constant_value( ...@@ -482,7 +482,7 @@ def get_scalar_constant_value(
var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:] var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:]
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, CType): if isinstance(idx, Type):
idx = get_scalar_constant_value( idx = get_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
...@@ -517,7 +517,7 @@ def get_scalar_constant_value( ...@@ -517,7 +517,7 @@ def get_scalar_constant_value(
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, CType): if isinstance(idx, Type):
idx = get_scalar_constant_value( idx = get_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
...@@ -539,7 +539,7 @@ def get_scalar_constant_value( ...@@ -539,7 +539,7 @@ def get_scalar_constant_value(
op = owner.op op = owner.op
idx_list = op.idx_list idx_list = op.idx_list
idx = idx_list[0] idx = idx_list[0]
if isinstance(idx, CType): if isinstance(idx, Type):
idx = get_scalar_constant_value( idx = get_scalar_constant_value(
owner.inputs[1], max_recur=max_recur owner.inputs[1], max_recur=max_recur
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论