提交 8d9da724 authored 作者: James Bergstra's avatar James Bergstra

Uncommenting softmax optimization

上级 22226ae5
......@@ -397,41 +397,41 @@ def local_softmax_with_bias(node):
return
return [sm_bias]
if 0:
def softmax_simplifier(numerators, denominators):
for numerator in list(numerators):
#TODO: a single softmax'd vector??
if not numerator.type.dtype.startswith('float'):
continue
if not numerator.type.broadcastable == (False, False):
continue
if numerator.owner and numerator.owner.op == tensor.exp:
x = numerator.owner.inputs[0]
else:
continue
matching_denom = None
for denominator in denominators:
if denominator.owner and isinstance(denominator.owner.op, tensor.DimShuffle):
if denominator.owner.op.new_order == (0,'x'):
z = denominator.owner.inputs[0] # thing getting dimshuffled
if z.owner and isinstance(z.owner.op, tensor.Sum):
#print 'ASDF', denominator.owner.op.new_order
#print z.owner.op.axis
if z.owner.op.axis == (1,):
#print "almost there.. softmax", x, z.owner.inputs[0]
if z.owner.inputs[0] is numerator:
matching_denom = denominator
break
if matching_denom:
numerators.remove(numerator)
denominators.remove(matching_denom)
numerators.append(softmax(x))
return numerators, denominators
opt.local_mul_canonizer.add_simplifier(softmax_simplifier, 'softmax_simplifier')
def softmax_simplifier(numerators, denominators):
for numerator in list(numerators):
#TODO: a single softmax'd vector??
if not numerator.type.dtype.startswith('float'):
continue
if not numerator.type.broadcastable == (False, False):
continue
if numerator.owner and numerator.owner.op == tensor.exp:
x = numerator.owner.inputs[0]
else:
continue
matching_denom = None
for denominator in denominators:
if denominator.owner and isinstance(denominator.owner.op, tensor.DimShuffle):
if denominator.owner.op.new_order == (0,'x'):
z = denominator.owner.inputs[0] # thing getting dimshuffled
if z.owner and isinstance(z.owner.op, tensor.Sum):
#print 'ASDF', denominator.owner.op.new_order
#print z.owner.op.axis
if z.owner.op.axis == (1,):
#print "almost there.. softmax", x, z.owner.inputs[0]
if z.owner.inputs[0] is numerator:
matching_denom = denominator
break
if matching_denom:
numerators.remove(numerator)
denominators.remove(matching_denom)
numerators.append(softmax(x))
return numerators, denominators
opt.local_mul_canonizer.add_simplifier(softmax_simplifier, 'softmax_simplifier')
if 0:
def softmax_grad_simplifier(numerators, denominators):
print "mul simplify numerators"
printing.debugprint(numerators)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论