提交 89bad99d authored 作者: Frederic Bastien's avatar Frederic Bastien

fix broadcast_like when the ShapeOpt is disabled.

上级 7de7d093
...@@ -101,13 +101,16 @@ def broadcast_like(value, template, env, dtype=None): ...@@ -101,13 +101,16 @@ def broadcast_like(value, template, env, dtype=None):
value = T.as_tensor_variable(value) value = T.as_tensor_variable(value)
if value.type == template.type: if value.type == template.type:
return value return value
shape_of = env.shape_feature.shape_of if template not in env.variables:
if template not in shape_of:
raise NotImplementedError('broadcast_like currently requires the ' raise NotImplementedError('broadcast_like currently requires the '
'template Variable to be in the env already') 'template Variable to be in the env already')
if hasattr(env, 'shape_feature'):
new_shape = env.shape_feature.shape_of[template]
else:
new_shape = template.shape
if dtype is None: if dtype is None:
dtype = template.dtype dtype = template.dtype
rval = T.alloc(T.cast(value, dtype), *shape_of[template]) rval = T.alloc(T.cast(value, dtype), *new_shape)
# the template may have 1s in its shape without being broadcastable # the template may have 1s in its shape without being broadcastable
if rval.broadcastable != template.broadcastable: if rval.broadcastable != template.broadcastable:
rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim) rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论