提交 c9dec33e authored 作者: abalkin's avatar abalkin

Repeat Op should not copy broadcastable flag for the axis along which it repeats.

上级 6fdfdb4e
......@@ -259,10 +259,13 @@ class RepeatOp(theano.Op):
% numpy_unsupported_dtypes), repeats.dtype)
if self.axis is None:
out_type = theano.tensor.TensorType(dtype=x.dtype,
broadcastable=[False])
broadcastable=[False]
else:
out_type = x.type
broadcastable = list(x.broadcastable)
broadcastable[self.axis] = False
out_type = theano.tensor.TensorType(x.dtype, broadcastable)
return theano.Apply(self, [x, repeats], [out_type()])
def perform(self, node, inputs, output_storage):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论