提交 5fbdf91c authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix more opt broadcast pattern change. This remove warning in jenkins

上级 265165cf
...@@ -3069,14 +3069,10 @@ def local_subtensor_of_alloc(node): ...@@ -3069,14 +3069,10 @@ def local_subtensor_of_alloc(node):
if type(rval) not in (list, tuple): if type(rval) not in (list, tuple):
rval = [rval] rval = [rval]
if rval[0].type != node.outputs[0].type: if rval[0].type != node.outputs[0].type:
# It happen that the make_node() isn't able to infer that some # It happen that the make_node() isn't able to infer the same pattern.
# dimensions are broadcastable, but that now we can infer # We know it is safe, so fix that.
# that. So we need to remove that information here. rval[0] = T.patternbroadcast(rval[0], node.outputs[0].broadcastable)
rval[0] = theano.tensor.unbroadcast(
rval[0],
*[i for i, (b1, b2) in enumerate(zip(rval[0].broadcastable,
node.outputs[0].broadcastable))
if b1 and not b2])
return rval return rval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论