提交 70cf7e3b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Brandon T. Willard

Fix bug in static type shape of DiagonalSubtensor Op output

上级 e40c8274
...@@ -106,7 +106,10 @@ class DiagonalSubtensor(Op): ...@@ -106,7 +106,10 @@ class DiagonalSubtensor(Op):
def make_node(self, x, i0, i1): def make_node(self, x, i0, i1):
_i0 = at.as_tensor_variable(i0) _i0 = at.as_tensor_variable(i0)
_i1 = at.as_tensor_variable(i1) _i1 = at.as_tensor_variable(i1)
return Apply(self, [x, _i0, _i1], [x.type()]) # TODO: We could produce a more precise static shape output type
type_shape = (1 if shape == 1 else None for shape in x.type.shape)
out_type = at.TensorType(x.type.dtype, shape=type_shape)
return Apply(self, [x, _i0, _i1], [out_type()])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
xview = get_diagonal_subtensor_view(*inputs) xview = get_diagonal_subtensor_view(*inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论