提交 1e480107 authored 作者: Frederic's avatar Frederic

Add unbroadcast in an opt.

上级 fcb54f8a
......@@ -2014,6 +2014,15 @@ def local_subtensor_of_alloc(node):
rval = T.alloc(nw_val, *nw_dims)
if type(rval) not in (list, tuple):
rval = [rval]
if rval[0].type != node.outputs[0].type:
#It happen that the make_node() isn't able to infer that some
#dimensions are broadcastable, but that now we can infer
#that. So we need to remove that information here.
rval[0] = theano.tensor.unbroadcast(
rval[0],
*[i for i, (b1, b2) in enumerate(zip(rval[0].broadcastable,
node.outputs[0].broadcastable))
if b1 and not b2])
return rval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论