提交 c7d06ac9 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

pep8 nnet

上级 67eba77c
...@@ -80,7 +80,7 @@ class SoftmaxWithBias(gof.Op): ...@@ -80,7 +80,7 @@ class SoftmaxWithBias(gof.Op):
g_sm, = grads g_sm, = grads
if isinstance(g_sm.type, DisconnectedType): if isinstance(g_sm.type, DisconnectedType):
return [ DisconnectedType()(), DisconnectedType()() ] return [DisconnectedType()(), DisconnectedType()()]
sm = softmax_with_bias(x, b) sm = softmax_with_bias(x, b)
dx = softmax_grad(g_sm, sm) dx = softmax_grad(g_sm, sm)
...@@ -561,8 +561,8 @@ if 0: ...@@ -561,8 +561,8 @@ if 0:
axis = ds_input.owner.op.axis axis = ds_input.owner.op.axis
sum_input = ds_input.owner.inputs[0] sum_input = ds_input.owner.inputs[0]
if ((ds_order!=(0,'x')) or if ((ds_order != (0, 'x')) or
(axis!=(1,)) or (axis != (1,)) or
(sum_input is not prod_term)): (sum_input is not prod_term)):
rest.append(add_in) rest.append(add_in)
#print 'ds_order =', ds_order #print 'ds_order =', ds_order
...@@ -715,20 +715,18 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -715,20 +715,18 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
def connection_pattern(self, node): def connection_pattern(self, node):
return [[True,True,True],#x return [[True, True, True], # x
[True,True,True],#b [True, True, True], # b
[False,False,True]]#y_idx [False, False, True]] # y_idx
def grad(self, inp, grads): def grad(self, inp, grads):
x, b, y_idx = inp x, b, y_idx = inp
g_nll, g_sm, g_am = grads g_nll, g_sm, g_am = grads
dx_terms = [] dx_terms = []
db_terms = [] db_terms = []
d_idx_terms = [] d_idx_terms = []
if not isinstance(g_nll.type, DisconnectedType): if not isinstance(g_nll.type, DisconnectedType):
nll, sm = crossentropy_softmax_1hot_with_bias(x, b, y_idx) nll, sm = crossentropy_softmax_1hot_with_bias(x, b, y_idx)
dx = crossentropy_softmax_1hot_with_bias_dx(g_nll, sm, y_idx) dx = crossentropy_softmax_1hot_with_bias_dx(g_nll, sm, y_idx)
...@@ -746,7 +744,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -746,7 +744,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
db_terms.append(b.zeros_like()) db_terms.append(b.zeros_like())
d_idx_terms.append(y_idx.zeros_like()) d_idx_terms.append(y_idx.zeros_like())
def fancy_sum( terms ): def fancy_sum(terms):
if len(terms) == 0: if len(terms) == 0:
return DisconnectedType()() return DisconnectedType()()
rval = terms[0] rval = terms[0]
...@@ -754,8 +752,8 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -754,8 +752,8 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
rval = rval + term rval = rval + term
return rval return rval
return [ fancy_sum(terms) for terms in return [fancy_sum(terms) for terms in
[dx_terms, db_terms, d_idx_terms ] ] [dx_terms, db_terms, d_idx_terms]]
def c_headers(self): def c_headers(self):
return ['<iostream>', '<cmath>'] return ['<iostream>', '<cmath>']
...@@ -1332,7 +1330,6 @@ def local_advanced_indexing_crossentropy_onehot(node): ...@@ -1332,7 +1330,6 @@ def local_advanced_indexing_crossentropy_onehot(node):
except Exception: except Exception:
pass pass
if sm is not None and sm.owner and sm.owner.op in (softmax, if sm is not None and sm.owner and sm.owner.op in (softmax,
softmax_with_bias): softmax_with_bias):
sm_w_bias = local_softmax_with_bias.transform(sm.owner) sm_w_bias = local_softmax_with_bias.transform(sm.owner)
...@@ -1488,7 +1485,8 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1488,7 +1485,8 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
if adv_subtensor is not None: if adv_subtensor is not None:
try: try:
maybe_sm, maybe_rows, maybe_labels = adv_subtensor.owner.inputs maybe_sm, maybe_rows,
maybe_labels = adv_subtensor.owner.inputs
except Exception: except Exception:
return return
...@@ -1698,7 +1696,6 @@ class Prepend_scalar_constant_to_each_row(gof.Op): ...@@ -1698,7 +1696,6 @@ class Prepend_scalar_constant_to_each_row(gof.Op):
shp = (in_shapes[0][0], in_shapes[0][1] + 1) shp = (in_shapes[0][0], in_shapes[0][1] + 1)
return [shp] return [shp]
def grad(self, inp, grads): def grad(self, inp, grads):
mat, = inp mat, = inp
goutput, = grads goutput, = grads
...@@ -1765,18 +1762,19 @@ prepend_1_to_each_row = Prepend_scalar_constant_to_each_row(1.) ...@@ -1765,18 +1762,19 @@ prepend_1_to_each_row = Prepend_scalar_constant_to_each_row(1.)
#numerically stabilize log softmax (X) #numerically stabilize log softmax (X)
# as X-X.max(axis=1).dimshuffle(0,'x') - log(exp(X-X.max(axis=1).dimshuffle(0,'x')).sum(axis=1)).dimshuffle(0,'x) # as X-X.max(axis=1).dimshuffle(0,'x') - log(exp(X-X.max(axis=1).dimshuffle(0,'x')).sum(axis=1)).dimshuffle(0,'x)
def make_out_pattern(X): def make_out_pattern(X):
stabilized_X = X - X.max(axis=1).dimshuffle(0,'x') stabilized_X = X - X.max(axis=1).dimshuffle(0, 'x')
out_var = stabilized_X - tensor.log(tensor.exp(stabilized_X).sum(axis=1)).dimshuffle(0,'x') out_var = stabilized_X - tensor.log(tensor.exp(stabilized_X).sum(
axis=1)).dimshuffle(0, 'x')
#tell DEBUG_MODE that it's OK if the original graph produced NaN and the optimized graph does not #tell DEBUG_MODE that it's OK if the original graph produced NaN and the optimized graph does not
out_var.values_eq_approx = out_var.type.values_eq_approx_remove_nan out_var.values_eq_approx = out_var.type.values_eq_approx_remove_nan
return out_var return out_var
local_log_softmax = gof.PatternSub( in_pattern = (tensor.log, (softmax, 'x')), local_log_softmax = gof.PatternSub(in_pattern=(tensor.log, (softmax, 'x')),
out_pattern = (make_out_pattern, 'x'), out_pattern=(make_out_pattern, 'x'),
allow_multiple_clients=True) allow_multiple_clients=True)
#don't do register_stabilize, this is to make local_log_softmax run #don't do register_stabilize, this is to make local_log_softmax run
#only after another more specific optimization that stabilizes cross entropy #only after another more specific optimization that stabilizes cross entropy
#opt.register_stabilize(local_log_softmax, name = 'local_log_softmax') #opt.register_stabilize(local_log_softmax, name = 'local_log_softmax')
opt.register_specialize(local_log_softmax, name = 'local_log_softmax') opt.register_specialize(local_log_softmax, name='local_log_softmax')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论