提交 b27b97b6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent ShapeFeature.make_vector_shape from creating no-ops

上级 810ee8f4
......@@ -60,7 +60,6 @@ from aesara.tensor.basic import (
get_scalar_constant_value,
get_vector_length,
join,
make_vector,
ones_like,
patternbroadcast,
switch,
......@@ -77,7 +76,7 @@ from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft
from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes, lscalar
from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes
from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter
......@@ -1191,7 +1190,7 @@ class ShapeFeature(features.Feature):
self.set_shape(r, self.shape_tuple(r))
def make_vector_shape(self, r):
return make_vector(*self.shape_of[r])
return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64")
def on_attach(self, fgraph):
......@@ -1207,7 +1206,7 @@ class ShapeFeature(features.Feature):
# Must be local to the object as otherwise we reuse the same
# variable for multiple fgraph!
self.lscalar_one = constant(1, dtype="int64")
assert self.lscalar_one.type == lscalar
assert self.lscalar_one.type.dtype == "int64"
self.fgraph = fgraph
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论