提交 9a453338 作者: James Bergstra

fixed optimization bugs triggered by propagation of shape 1 to unbroadcastable dims

Fixed optimization bugs triggered by propagating a shape of 1 to non-broadcastable dims. Hint: prefer broadcast_like() to alloc().
上级 d99d6d32
......@@ -90,6 +90,10 @@ def broadcast_like(value, template, env):
if template not in shape_of:
raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already')
rval = T.alloc(T.cast(value, template.dtype), *shape_of[template])
# the template may have 1s in its shape without being broadcastable
if rval.broadcastable != template.broadcastable:
rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim) if rval.broadcastable[i]
and not template.broadcastable[i]])
assert rval.type == template.type
return rval
......@@ -663,14 +667,20 @@ def local_fill_to_alloc(node):
elif v.type.broadcastable == node.outputs[0].type.broadcastable:
# this is a cast
rval = [T.cast(v, node.outputs[0].type.dtype)]
elif r.type.broadcastable == node.outputs[0].type.broadcastable:
# we are broadcasting v somehow, but not r
rval = [broadcast_like(v, r, node.env)]
else:
# we are broadcasting v somehow
shape_of = node.env.shape_feature.shape_of
# we are broadcasting both v and r,
# the output shape must be computed
#
# TODO: implement this case (including a test!)
#
# I think the strategy should be to extend the shorter shape vector
# with 1s (how?) and then take the elementwise max of the two.
# - how to flag an error of shape mismatch where broadcasting should be illegal?
return
# TODO: cut out un-necessary dimshuffles of v
rval = [T.alloc(T.cast(v, node.outputs[0].dtype), *shape_of[node.outputs[0]])]
#if rval[0].type != node.outputs[0].type:
#print >> sys.stderr, theano.printing.debugprint(node.outputs[0], file='str')
assert rval[0].type == node.outputs[0].type, ('rval', rval[0].type,
'orig', node.outputs[0].type,
......@@ -2259,8 +2269,7 @@ def local_mul_specialize(node):
neg ^= True #toggles
elif N.all(y == 0.0):
# if we find any zero, we just return right away
return [T.alloc(numpy.asarray(0, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
return [broadcast_like(0, node.outputs[0], node.env)]
else:
new_inputs.append(input)
......@@ -2277,21 +2286,14 @@ def local_mul_specialize(node):
else:
rval = T.mul(*new_inputs)
return [T.alloc(T.cast(rval, node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
return [broadcast_like(rval, node.outputs[0], node.env)]
else:
# there are no variable inputs to mul
# N.B. this could have been constant-folded...
if neg:
# return output's worth of -1
return [T.alloc(
numpy.asarray(-1, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
return [broadcast_like(-1, node.outputs[0], node.env)]
else:
# return output's worth of 1
return [T.alloc(
numpy.asarray(1, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
return [broadcast_like(1, node.outputs[0], node.env)]
register_specialize(local_mul_specialize)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论