提交 84e69fc8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Brandon T. Willard

Allow TensorType convert_variable to work with mixed static shapes

上级 7d0edb87
......@@ -322,16 +322,13 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def convert_variable(self, var):
if self.is_super(var.type):
# `var.type` is at least as specific as `self`, so we return
# `var` as-is
# `var.type` is as specific as `self`, so we return `var` as-is
return var
elif var.type.is_super(self):
# `var.type` is less specific than `self`, so we convert
# `var` to `self`'s `Type`.
# Note that, in this case, `var.type != self`, because that's
# covered by the branch above.
# Use the more specific static shape information of the two
if (self.ndim == var.type.ndim) and (self.dtype == var.type.dtype):
# `var.type` only differs from `self` in that its shape is (at least partially)
# less specific than `self`, so we convert `var` to `self`'s `Type`.
# `specify_shape` will combine the more precise shapes of the two types
return aesara.tensor.specify_shape(var, self.shape)
def value_zeros(self, shape):
......
......@@ -71,6 +71,18 @@ def test_convert_variable():
assert res is const_var
def test_convert_variable_mixed_specificity():
type1 = TensorType(config.floatX, shape=(1, None, 3))
type2 = TensorType(config.floatX, shape=(None, 5, 3))
type3 = TensorType(config.floatX, shape=(1, 5, 3))
test_var1 = type1()
test_var2 = type2()
assert type1.convert_variable(test_var2).type == type3
assert type2.convert_variable(test_var1).type == type3
def test_filter_variable():
test_type = TensorType(config.floatX, [])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论