提交 c657bad8 authored 作者: Frederic Bastien's avatar Frederic Bastien

Create less temporary node.

上级 12e58299
...@@ -155,13 +155,16 @@ def broadcast_like(value, template, fgraph, dtype=None): ...@@ -155,13 +155,16 @@ def broadcast_like(value, template, fgraph, dtype=None):
if template not in fgraph.variables: if template not in fgraph.variables:
raise NotImplementedError('broadcast_like currently requires the ' raise NotImplementedError('broadcast_like currently requires the '
'template Variable to be in the fgraph already') 'template Variable to be in the fgraph already')
if dtype is None:
dtype = template.dtype
value = T.cast(value, dtype)
if value.type == template.type:
return value
if hasattr(fgraph, 'shape_feature'): if hasattr(fgraph, 'shape_feature'):
new_shape = fgraph.shape_feature.shape_of[template] new_shape = fgraph.shape_feature.shape_of[template]
else: else:
new_shape = template.shape new_shape = template.shape
if dtype is None: rval = T.alloc(value, *new_shape)
dtype = template.dtype
rval = T.alloc(T.cast(value, dtype), *new_shape)
# the template may have 1s in its shape without being broadcastable # the template may have 1s in its shape without being broadcastable
if rval.broadcastable != template.broadcastable: if rval.broadcastable != template.broadcastable:
rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim) rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论