提交 e7a98de4 authored 作者: James Bergstra's avatar James Bergstra

fix to fix to fill_local_alloc

上级 06645b5c
......@@ -81,7 +81,7 @@ def scalarconsts_rest(inputs):
nonconsts.append(i)
return consts, origconsts, nonconsts
def broadcast_like(value, template, env):
def broadcast_like(value, template, env, dtype=None):
"""Return a Variable with the same shape and dtype as the template,
filled by broadcasting value through it. `value` will be casted as necessary.
......@@ -89,12 +89,15 @@ def broadcast_like(value, template, env):
shape_of = env.shape_feature.shape_of
if template not in shape_of:
raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already')
rval = T.alloc(T.cast(value, template.dtype), *shape_of[template])
if dtype is None:
dtype = template.dtype
rval = T.alloc(T.cast(value, dtype), *shape_of[template])
# the template may have 1s in its shape without being broadcastable
if rval.broadcastable != template.broadcastable:
rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim) if rval.broadcastable[i]
and not template.broadcastable[i]])
assert rval.type == template.type
assert rval.type.dtype == dtype
assert rval.type.broadcastable == template.broadcastable
return rval
......@@ -669,7 +672,7 @@ def local_fill_to_alloc(node):
rval = [T.cast(v, node.outputs[0].type.dtype)]
elif r.type.broadcastable == node.outputs[0].type.broadcastable:
# we are broadcasting v somehow, but not r
rval = [broadcast_like(v, r, node.env)]
rval = [broadcast_like(v, r, node.env, dtype=v.dtype)]
else:
# we are broadcasting both v and r,
# the output shape must be computed
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论