提交 9a453338 authored 作者: James Bergstra's avatar James Bergstra

fixed optimization bugs triggered by propagation of shape 1 to unbroadcastable…

fixed optimization bugs triggered by propagation of shape 1 to unbroadcastable dims - hint: prefer broadcast_like() to alloc()
上级 d99d6d32
...@@ -90,6 +90,10 @@ def broadcast_like(value, template, env): ...@@ -90,6 +90,10 @@ def broadcast_like(value, template, env):
if template not in shape_of: if template not in shape_of:
raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already') raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already')
rval = T.alloc(T.cast(value, template.dtype), *shape_of[template]) rval = T.alloc(T.cast(value, template.dtype), *shape_of[template])
# the template may have 1s in its shape without being broadcastable
if rval.broadcastable != template.broadcastable:
rval = T.unbroadcast(rval, *[i for i in xrange(rval.ndim) if rval.broadcastable[i]
and not template.broadcastable[i]])
assert rval.type == template.type assert rval.type == template.type
return rval return rval
...@@ -663,14 +667,20 @@ def local_fill_to_alloc(node): ...@@ -663,14 +667,20 @@ def local_fill_to_alloc(node):
elif v.type.broadcastable == node.outputs[0].type.broadcastable: elif v.type.broadcastable == node.outputs[0].type.broadcastable:
# this is a cast # this is a cast
rval = [T.cast(v, node.outputs[0].type.dtype)] rval = [T.cast(v, node.outputs[0].type.dtype)]
elif r.type.broadcastable == node.outputs[0].type.broadcastable:
# we are broadcasting v somehow, but not r
rval = [broadcast_like(v, r, node.env)]
else: else:
# we are broadcasting v somehow # we are broadcasting both v and r,
shape_of = node.env.shape_feature.shape_of # the output shape must be computed
#
# TODO: implement this case (including a test!)
#
# I think the strategy should be to extend the shorter shape vector
# with 1s (how?) and then take the elementwise max of the two.
# - how to flag an error of shape mismatch where broadcasting should be illegal?
return
# TODO: cut out un-necessary dimshuffles of v # TODO: cut out un-necessary dimshuffles of v
rval = [T.alloc(T.cast(v, node.outputs[0].dtype), *shape_of[node.outputs[0]])]
#if rval[0].type != node.outputs[0].type:
#print >> sys.stderr, theano.printing.debugprint(node.outputs[0], file='str')
assert rval[0].type == node.outputs[0].type, ('rval', rval[0].type, assert rval[0].type == node.outputs[0].type, ('rval', rval[0].type,
'orig', node.outputs[0].type, 'orig', node.outputs[0].type,
...@@ -2259,8 +2269,7 @@ def local_mul_specialize(node): ...@@ -2259,8 +2269,7 @@ def local_mul_specialize(node):
neg ^= True #toggles neg ^= True #toggles
elif N.all(y == 0.0): elif N.all(y == 0.0):
# if we find any zero, we just return right away # if we find any zero, we just return right away
return [T.alloc(numpy.asarray(0, dtype=node.outputs[0].dtype), return [broadcast_like(0, node.outputs[0], node.env)]
*node.env.shape_feature.shape_of[node.outputs[0]])]
else: else:
new_inputs.append(input) new_inputs.append(input)
...@@ -2277,21 +2286,14 @@ def local_mul_specialize(node): ...@@ -2277,21 +2286,14 @@ def local_mul_specialize(node):
else: else:
rval = T.mul(*new_inputs) rval = T.mul(*new_inputs)
return [T.alloc(T.cast(rval, node.outputs[0].dtype), return [broadcast_like(rval, node.outputs[0], node.env)]
*node.env.shape_feature.shape_of[node.outputs[0]])]
else: else:
# there are no variable inputs to mul # there are no variable inputs to mul
# N.B. this could have been constant-folded... # N.B. this could have been constant-folded...
if neg: if neg:
# return output's worth of -1 return [broadcast_like(-1, node.outputs[0], node.env)]
return [T.alloc(
numpy.asarray(-1, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
else: else:
# return output's worth of 1 return [broadcast_like(1, node.outputs[0], node.env)]
return [T.alloc(
numpy.asarray(1, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])]
register_specialize(local_mul_specialize) register_specialize(local_mul_specialize)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论