Commit de37ce49 authored by James Bergstra

Adding commented-out softmax optimizations. They are under construction.

Parent 0257b2b6
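For context, the identity these simplifiers target is the row-wise softmax, softmax(x) = exp(x) / sum(exp(x), axis=1). Below is a minimal NumPy sketch of the pattern being matched (illustration only, not part of this patch):

```python
import numpy as np

def softmax_rows(x):
    # Written in exactly the form softmax_simplifier matches: exp(x)
    # divided by the row sums of exp(x), broadcast back over columns.
    e = np.exp(x)
    return e / e.sum(axis=1)[:, None]  # [:, None] plays the role of dimshuffle(0, 'x')

x = np.random.randn(3, 5)
assert np.allclose(softmax_rows(x).sum(axis=1), 1.0)  # rows are normalized
```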
...
@@ -397,6 +397,83 @@ def local_softmax_with_bias(node):
                return
        return [sm_bias]
if 0:
    def softmax_simplifier(numerators, denominators):
        # Look for a numerator exp(x) paired with a denominator
        # sum(exp(x), axis=1).dimshuffle(0, 'x') and collapse the pair
        # into a single softmax(x) node.
        for numerator in list(numerators):
            #TODO: a single softmax'd vector??
            if not numerator.type.dtype.startswith('float'):
                continue
            if not numerator.type.broadcastable == (False, False):
                continue
            if numerator.owner and numerator.owner.op == tensor.exp:
                x = numerator.owner.inputs[0]
            else:
                continue

            matching_denom = None
            for denominator in denominators:
                if denominator.owner and isinstance(denominator.owner.op,
                                                    tensor.DimShuffle):
                    if denominator.owner.op.new_order == (0, 'x'):
                        z = denominator.owner.inputs[0]  # thing getting dimshuffled
                        if z.owner and isinstance(z.owner.op, tensor.Sum):
                            print 'ASDF', denominator.owner.op.new_order
                            print z.owner.op.axis
                            if z.owner.op.axis == (1,):
                                print "almost there.. softmax", x, z.owner.inputs[0]
                                if z.owner.inputs[0] is numerator:
                                    matching_denom = denominator
                                    break
            if matching_denom:
                # Replace the matched exp/sum pair with a single softmax node.
                numerators.remove(numerator)
                denominators.remove(matching_denom)
                numerators.append(softmax(x))
        return numerators, denominators
    opt.local_mul_canonizer.add_simplifier(softmax_simplifier,
                                           'softmax_simplifier')
    def softmax_grad_simplifier(numerators, denominators):
        # Under construction: look for the exp(x) / (softmax(x) + other)
        # pattern that shows up in the gradient of softmax.  Note that
        # matching_denom is never assigned below, so this simplifier does
        # not rewrite anything yet.
        print "mul simplify numerators"
        debugprint(numerators)
        print "mul simplify denominators"
        debugprint(denominators)
        for numerator in list(numerators):
            #TODO: a single softmax'd vector??
            if not numerator.type.dtype.startswith('float'):
                continue
            if not numerator.type.broadcastable == (False, False):
                continue
            if numerator.owner and numerator.owner.op == tensor.exp:
                x = numerator.owner.inputs[0]
            else:
                continue

            print "A", denominators
            matching_denom = None
            for denominator in denominators:
                if denominator.owner and denominator.owner.op == tensor.add:
                    if len(denominator.owner.inputs) == 2:
                        dl, dr = denominator.owner.inputs
                        # check to see if either dl or dr is softmax(x);
                        # if yes, we are probably dealing with the gradient
                        # of softmax
                        other = None
                        if dl.owner and dl.owner.op == softmax and dl.owner.inputs[0] == x:
                            other = dr
                        if dr.owner and dr.owner.op == softmax and dr.owner.inputs[0] == x:
                            other = dl
                        if other:
                            print "OTHER", other
            if matching_denom:
                numerators.remove(numerator)
                denominators.remove(matching_denom)
                numerators.append(softmax(x))
        return numerators, denominators
    opt.local_mul_canonizer.add_simplifier(softmax_grad_simplifier,
                                           'softmax_grad_simplifier')
class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
    """A special compound L{Op} for the output of neural-net classifiers.
...
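A note on the second simplifier: the denominator pattern softmax(x) + other that it scans for is, per the inline comment, probably coming from the graph of softmax's gradient. For reference, if y = softmax(x) row-wise and g is the gradient flowing into y, the gradient with respect to x is y * (g - sum(g * y, axis=1)). A NumPy finite-difference check of that identity (illustration only, not part of this patch):

```python
import numpy as np

def softmax_rows(x):
    e = np.exp(x)
    return e / e.sum(axis=1)[:, None]

def softmax_grad(g, y):
    # Analytic gradient of row-wise softmax: dx = y * (g - sum(g * y, axis=1)).
    return y * (g - (g * y).sum(axis=1)[:, None])

rng = np.random.RandomState(0)
x = rng.randn(2, 4)
g = rng.randn(2, 4)
y = softmax_rows(x)

# Compare against central finite differences of L(x) = sum(g * softmax(x)).
eps, fd = 1e-6, np.empty_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        xp, xm = x.copy(), x.copy()
        xp[i, j] += eps
        xm[i, j] -= eps
        fd[i, j] = ((g * softmax_rows(xp)).sum()
                    - (g * softmax_rows(xm)).sum()) / (2 * eps)

assert np.allclose(fd, softmax_grad(g, y), atol=1e-5)
```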