提交 b27b97b6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent ShapeFeature.make_vector_shape from creating no-ops

上级 810ee8f4
...@@ -60,7 +60,6 @@ from aesara.tensor.basic import ( ...@@ -60,7 +60,6 @@ from aesara.tensor.basic import (
get_scalar_constant_value, get_scalar_constant_value,
get_vector_length, get_vector_length,
join, join,
make_vector,
ones_like, ones_like,
patternbroadcast, patternbroadcast,
switch, switch,
...@@ -77,7 +76,7 @@ from aesara.tensor.math import eq ...@@ -77,7 +76,7 @@ from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft
from aesara.tensor.sort import TopKOp from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes, lscalar from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter from aesara.utils import NoDuplicateOptWarningFilter
...@@ -1191,7 +1190,7 @@ class ShapeFeature(features.Feature): ...@@ -1191,7 +1190,7 @@ class ShapeFeature(features.Feature):
self.set_shape(r, self.shape_tuple(r)) self.set_shape(r, self.shape_tuple(r))
def make_vector_shape(self, r): def make_vector_shape(self, r):
return make_vector(*self.shape_of[r]) return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64")
def on_attach(self, fgraph): def on_attach(self, fgraph):
...@@ -1207,7 +1206,7 @@ class ShapeFeature(features.Feature): ...@@ -1207,7 +1206,7 @@ class ShapeFeature(features.Feature):
# Must be local to the object as otherwise we reuse the same # Must be local to the object as otherwise we reuse the same
# variable for multiple fgraph! # variable for multiple fgraph!
self.lscalar_one = constant(1, dtype="int64") self.lscalar_one = constant(1, dtype="int64")
assert self.lscalar_one.type == lscalar assert self.lscalar_one.type.dtype == "int64"
self.fgraph = fgraph self.fgraph = fgraph
# Variable -> tuple(scalars) or None (All tensor vars map to tuple) # Variable -> tuple(scalars) or None (All tensor vars map to tuple)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论