提交 8d9da724 authored 作者: James Bergstra's avatar James Bergstra

Uncommenting softmax optimization

上级 22226ae5
...@@ -397,41 +397,41 @@ def local_softmax_with_bias(node): ...@@ -397,41 +397,41 @@ def local_softmax_with_bias(node):
return return
return [sm_bias] return [sm_bias]
def softmax_simplifier(numerators, denominators):
    """Simplifier hook for the elemwise-mul canonizer.

    Scans a multiplication written as ``prod(numerators) / prod(denominators)``
    and rewrites the pattern ``exp(x) / exp(x).sum(axis=1).dimshuffle(0, 'x')``
    into a single ``softmax(x)`` node.

    Parameters
    ----------
    numerators : list
        Numerator variables of the canonicalized multiplication.
        Mutated in place when a match is found.
    denominators : list
        Denominator variables; mutated in place when a match is found.

    Returns
    -------
    tuple
        The (possibly modified) ``(numerators, denominators)`` pair,
        as required by the canonizer's simplifier protocol.
    """
    # Iterate over a copy: we remove from `numerators` while looping.
    for numerator in list(numerators):
        # TODO: a single softmax'd vector??
        if not numerator.type.dtype.startswith('float'):
            continue

        # Only handle plain matrices (no broadcastable dimensions).
        if not numerator.type.broadcastable == (False, False):
            continue
        # The numerator must literally be exp(x).
        if numerator.owner and numerator.owner.op == tensor.exp:
            x = numerator.owner.inputs[0]
        else:
            continue

        matching_denom = None

        # Look for a denominator of the form
        # exp(x).sum(axis=1).dimshuffle(0, 'x') over the SAME exp node.
        for denominator in denominators:
            if denominator.owner and isinstance(denominator.owner.op,
                                                tensor.DimShuffle):
                if denominator.owner.op.new_order == (0, 'x'):
                    # The thing getting dimshuffled.
                    z = denominator.owner.inputs[0]
                    if z.owner and isinstance(z.owner.op, tensor.Sum):
                        # Row-wise sum only.
                        if z.owner.op.axis == (1,):
                            # Must sum the very same exp(x) variable
                            # (identity check, not structural equality).
                            if z.owner.inputs[0] is numerator:
                                matching_denom = denominator
                                break
        if matching_denom:
            numerators.remove(numerator)
            denominators.remove(matching_denom)
            numerators.append(softmax(x))
    return numerators, denominators
opt.local_mul_canonizer.add_simplifier(softmax_simplifier, 'softmax_simplifier')
if 0:
def softmax_grad_simplifier(numerators, denominators): def softmax_grad_simplifier(numerators, denominators):
print "mul simplify numerators" print "mul simplify numerators"
printing.debugprint(numerators) printing.debugprint(numerators)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论