提交 e05036f0 作者: Frédéric Bastien

Merge pull request #2847 from harmdevries89/softmax_iss2050

[MRG] softmax function that builds expression instead of using softmax op
...@@ -739,7 +739,7 @@ class T_Scan(unittest.TestCase): ...@@ -739,7 +739,7 @@ class T_Scan(unittest.TestCase):
def forward_scanner(x_t): def forward_scanner(x_t):
a2_t = tensor.dot(x_t, W) a2_t = tensor.dot(x_t, W)
y_t = tensor.nnet.softmax(a2_t) y_t = tensor.nnet.softmax_graph(a2_t)
return y_t return y_t
y, _ = theano.scan(fn=forward_scanner, sequences=x, y, _ = theano.scan(fn=forward_scanner, sequences=x,
......
...@@ -78,12 +78,17 @@ class SoftmaxWithBias(gof.Op): ...@@ -78,12 +78,17 @@ class SoftmaxWithBias(gof.Op):
if b.shape[0] != x.shape[1]: if b.shape[0] != x.shape[1]:
raise ValueError('b must have same number of columns as x') raise ValueError('b must have same number of columns as x')
sm = numpy.zeros_like(x) # sm = numpy.zeros_like(x)
for i in xrange(sm.shape[0]): # for i in xrange(sm.shape[0]):
row = x[i] + b # row = x[i] + b
sm[i] = numpy.exp(row - numpy.max(row)) # sm[i] = numpy.exp(row - numpy.max(row))
sm[i] *= 1.0 / numpy.sum(sm[i]) # sm[i] *= 1.0 / numpy.sum(sm[i])
output_storage[0][0] = sm # output_storage[0][0] = sm
x_plus_b = x + b[None, :]
e_x = numpy.exp(x_plus_b - x_plus_b.max(axis=1)[:, None])
e_x *= 1.0 / e_x.sum(axis=1)[:, None]
output_storage[0][0] = e_x
def grad(self, inp, grads): def grad(self, inp, grads):
x, b = inp x, b = inp
...@@ -304,8 +309,17 @@ class SoftmaxGrad(gof.Op): ...@@ -304,8 +309,17 @@ class SoftmaxGrad(gof.Op):
dx[i] = dy_times_sm_i - sum(dy_times_sm_i) * sm[i] dx[i] = dy_times_sm_i - sum(dy_times_sm_i) * sm[i]
output_storage[0][0] = dx output_storage[0][0] = dx
def grad(self, *args): def grad(self, inp, grads):
raise NotImplementedError() dy, sm = inp
g, = grads
tmp = g + tensor.neg(tensor.sum(g*sm, axis=1).dimshuffle((0, 'x')))
g_dy = tmp * sm
tmp2 = tensor.sum(dy*sm, axis=1).dimshuffle((0, 'x'))
g_sm = tmp*dy - g *tmp2
return g_dy, g_sm
def infer_shape(self, node, shape): def infer_shape(self, node, shape):
return [shape[1]] return [shape[1]]
...@@ -414,7 +428,7 @@ class Softmax(gof.Op): ...@@ -414,7 +428,7 @@ class Softmax(gof.Op):
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
g_sm, = grads g_sm, = grads
sm = softmax(x) sm = softmax_op(x)
return [softmax_grad(g_sm, sm)] return [softmax_grad(g_sm, sm)]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
...@@ -568,15 +582,20 @@ class Softmax(gof.Op): ...@@ -568,15 +582,20 @@ class Softmax(gof.Op):
def c_code_cache_version(): def c_code_cache_version():
return (3,) return (3,)
softmax = Softmax() softmax_op = Softmax()
def softmax_graph(c):
    """Build a softmax expression from elementary tensor operations.

    The result is ``exp(c)`` divided by the sum of ``exp(c)`` over the last
    axis (taken with ``keepdims=True``). Unlike ``softmax``, this does not
    apply ``softmax_op`` directly.
    """
    exp_c = tensor.exp(c)
    normalizer = exp_c.sum(axis=-1, keepdims=True)
    return exp_c / normalizer
def softmax(c):
    """Apply the module-level ``softmax_op`` to ``c`` and return its output."""
    result = softmax_op(c)
    return result
@opt.register_specialize('fast_compile_gpu') @opt.register_specialize('fast_compile_gpu')
@gof.local_optimizer([softmax]) @gof.local_optimizer([softmax_op])
def local_softmax_with_bias(node): def local_softmax_with_bias(node):
"""Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias) """Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias)
""" """
if node.op == softmax: if node.op == softmax_op:
x, = node.inputs x, = node.inputs
if x.owner and x.owner.op == tensor.add: if x.owner and x.owner.op == tensor.add:
vectors = [] vectors = []
...@@ -638,7 +657,7 @@ def softmax_simplifier(numerators, denominators): ...@@ -638,7 +657,7 @@ def softmax_simplifier(numerators, denominators):
if not numerator.type.dtype.startswith('float'): if not numerator.type.dtype.startswith('float'):
continue continue
if not numerator.type.broadcastable == (False, False): if numerator.ndim != 2:
continue continue
if numerator.owner and numerator.owner.op == tensor.exp: if numerator.owner and numerator.owner.op == tensor.exp:
x = numerator.owner.inputs[0] x = numerator.owner.inputs[0]
...@@ -664,7 +683,8 @@ def softmax_simplifier(numerators, denominators): ...@@ -664,7 +683,8 @@ def softmax_simplifier(numerators, denominators):
if matching_denom: if matching_denom:
numerators.remove(numerator) numerators.remove(numerator)
denominators.remove(matching_denom) denominators.remove(matching_denom)
numerators.append(softmax(x)) numerators.append(softmax_op(x))
return numerators, denominators return numerators, denominators
opt.local_mul_canonizer.add_simplifier(softmax_simplifier, opt.local_mul_canonizer.add_simplifier(softmax_simplifier,
'softmax_simplifier') 'softmax_simplifier')
...@@ -1404,7 +1424,7 @@ def crossentropy_to_crossentropy_with_softmax(fgraph): ...@@ -1404,7 +1424,7 @@ def crossentropy_to_crossentropy_with_softmax(fgraph):
if node.op == crossentropy_categorical_1hot: if node.op == crossentropy_categorical_1hot:
nll, = node.outputs nll, = node.outputs
sm, one_of_n = node.inputs sm, one_of_n = node.inputs
if sm.owner and sm.owner.op == softmax: if sm.owner and sm.owner.op == softmax_op:
x, = sm.owner.inputs x, = sm.owner.inputs
new_nll, new_sm, new_am = crossentropy_softmax_argmax_1hot_with_bias(x, new_nll, new_sm, new_am = crossentropy_softmax_argmax_1hot_with_bias(x,
tensor.zeros_like(x[0]), one_of_n) tensor.zeros_like(x[0]), one_of_n)
...@@ -1450,7 +1470,7 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(node): ...@@ -1450,7 +1470,7 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(node):
def local_argmax_pushdown(node): def local_argmax_pushdown(node):
if node.op == tensor._max_and_argmax and node.inputs[0].owner and \ if node.op == tensor._max_and_argmax and node.inputs[0].owner and \
len(node.outputs[0].clients) > 0 and node.inputs[0].owner.op in \ len(node.outputs[0].clients) > 0 and node.inputs[0].owner.op in \
(softmax, softplus, tensor.exp, tensor.log, tensor.tanh, sigmoid, (softmax_op, softplus, tensor.exp, tensor.log, tensor.tanh, sigmoid,
softmax_with_bias): softmax_with_bias):
if theano.config.warn.argmax_pushdown_bug: if theano.config.warn.argmax_pushdown_bug:
logging.getLogger('theano.tensor.nnet.nnet').warn("WARNING: there " logging.getLogger('theano.tensor.nnet.nnet').warn("WARNING: there "
...@@ -1466,7 +1486,7 @@ def local_argmax_pushdown(node): ...@@ -1466,7 +1486,7 @@ def local_argmax_pushdown(node):
x_max, x_argmax = node.outputs x_max, x_argmax = node.outputs
x, axis = node.inputs x, axis = node.inputs
# TODO: Make a list/set of monotonic ops... # TODO: Make a list/set of monotonic ops...
if x.owner and x.owner.op in (softmax, softplus, tensor.exp, if x.owner and x.owner.op in (softmax_op, softplus, tensor.exp,
tensor.log, tensor.tanh, sigmoid): tensor.log, tensor.tanh, sigmoid):
pre_x, = x.owner.inputs pre_x, = x.owner.inputs
return tensor._max_and_argmax(pre_x, axis) return tensor._max_and_argmax(pre_x, axis)
...@@ -1554,7 +1574,7 @@ def local_advanced_indexing_crossentropy_onehot(node): ...@@ -1554,7 +1574,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
except Exception: except Exception:
pass pass
if sm is not None and sm.owner and sm.owner.op in (softmax, if sm is not None and sm.owner and sm.owner.op in (softmax_op,
softmax_with_bias): softmax_with_bias):
sm_w_bias = local_softmax_with_bias.transform(sm.owner) sm_w_bias = local_softmax_with_bias.transform(sm.owner)
if sm_w_bias: if sm_w_bias:
...@@ -1584,7 +1604,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1584,7 +1604,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
except Exception: except Exception:
return return
if (sm is not None) and sm.owner and (sm.owner.op in (softmax, if (sm is not None) and sm.owner and (sm.owner.op in (softmax_op,
softmax_with_bias)): softmax_with_bias)):
sm_w_bias = local_softmax_with_bias.transform(sm.owner) sm_w_bias = local_softmax_with_bias.transform(sm.owner)
if sm_w_bias: if sm_w_bias:
...@@ -2054,7 +2074,7 @@ def make_out_pattern(X): ...@@ -2054,7 +2074,7 @@ def make_out_pattern(X):
return out_var return out_var
local_log_softmax = gof.PatternSub(in_pattern=(tensor.log, (softmax, 'x')), local_log_softmax = gof.PatternSub(in_pattern=(tensor.log, (softmax_op, 'x')),
out_pattern=(make_out_pattern, 'x'), out_pattern=(make_out_pattern, 'x'),
allow_multiple_clients=True) allow_multiple_clients=True)
......
...@@ -22,8 +22,8 @@ from theano.tensor.nnet import (categorical_crossentropy, ...@@ -22,8 +22,8 @@ from theano.tensor.nnet import (categorical_crossentropy,
CrossentropySoftmaxArgmax1HotWithBias, CrossentropySoftmaxArgmax1HotWithBias,
CrossentropyCategorical1Hot, CrossentropyCategorical1Hot,
CrossentropyCategorical1HotGrad, CrossentropyCategorical1HotGrad,
sigmoid, softplus, sigmoid, softplus, Softmax, softmax,
Softmax, softmax, SoftmaxWithBias, softmax_op, softmax_graph, SoftmaxWithBias,
softmax_grad, softmax_grad,
softmax_with_bias, SoftmaxGrad, softmax_with_bias, SoftmaxGrad,
Prepend_scalar_constant_to_each_row, Prepend_scalar_constant_to_each_row,
...@@ -54,40 +54,40 @@ class T_Softmax(utt.InferShapeTester): ...@@ -54,40 +54,40 @@ class T_Softmax(utt.InferShapeTester):
def test0(self): def test0(self):
def f(a): def f(a):
return softmax(a)[:, 0] return softmax_op(a)[:, 0]
utt.verify_grad(f, [numpy.random.rand(3, 4)]) utt.verify_grad(f, [numpy.random.rand(3, 4)])
def test1(self): def test1(self):
def f(a): def f(a):
return softmax(a)[:, 1] return softmax_op(a)[:, 1]
utt.verify_grad(f, [numpy.random.rand(3, 4)]) utt.verify_grad(f, [numpy.random.rand(3, 4)])
def test2(self): def test2(self):
def f(a): def f(a):
return softmax(a)[:, 2] return softmax_op(a)[:, 2]
utt.verify_grad(f, [numpy.random.rand(3, 4)]) utt.verify_grad(f, [numpy.random.rand(3, 4)])
def test3(self): def test3(self):
def f(a): def f(a):
return softmax(a)[:, 3] return softmax_op(a)[:, 3]
utt.verify_grad(f, [numpy.random.rand(3, 4)]) utt.verify_grad(f, [numpy.random.rand(3, 4)])
def test_infer_shape(self): def test_infer_shape(self):
admat = matrix() admat = matrix()
admat_val = numpy.random.rand(3, 4).astype(config.floatX) admat_val = numpy.random.rand(3, 4).astype(config.floatX)
self._compile_and_check([admat], [Softmax()(admat)], self._compile_and_check([admat], [Softmax()(admat)],
[admat_val], Softmax) [admat_val], Softmax)
def test_vector(self): def test_vector(self):
x = T.vector() x = T.vector()
f = theano.function([x], softmax(x)) f = theano.function([x], softmax_op(x))
xv = numpy.random.randn(6).astype(config.floatX) xv = numpy.random.randn(6).astype(config.floatX)
assert numpy.allclose(f(xv), numpy.exp(xv) / numpy.exp(xv).sum()) assert numpy.allclose(f(xv), numpy.exp(xv) / numpy.exp(xv).sum())
def test_vector_grad(self): def test_vector_grad(self):
def f(a): def f(a):
return softmax(a) return softmax_op(a)
utt.verify_grad(f, [numpy.random.rand(4)]) utt.verify_grad(f, [numpy.random.rand(4)])
...@@ -129,10 +129,10 @@ class T_SoftmaxWithBias(utt.InferShapeTester): ...@@ -129,10 +129,10 @@ class T_SoftmaxWithBias(utt.InferShapeTester):
vbias = theano.shared(value=0.1, name='vbias') # 0.01 vbias = theano.shared(value=0.1, name='vbias') # 0.01
hid = T.vector('hid') hid = T.vector('hid')
f = theano.function([hid], f = theano.function([hid],
T.nnet.softmax(T.dot(hid, W.T) + vbias)) T.nnet.softmax_op(T.dot(hid, W.T) + vbias))
ops = [node.op for node in f.maker.fgraph.toposort()] ops = [node.op for node in f.maker.fgraph.toposort()]
assert softmax_with_bias not in ops assert softmax_with_bias not in ops
assert softmax in ops assert softmax_op in ops
f([0, 1, 0]) f([0, 1, 0])
# print f.maker.fgraph.toposort() # print f.maker.fgraph.toposort()
...@@ -400,7 +400,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -400,7 +400,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
fgraph = gof.FunctionGraph( fgraph = gof.FunctionGraph(
[x, one_of_n], [x, one_of_n],
[op(softmax(x), one_of_n)]) [op(softmax_op(x), one_of_n)])
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
...@@ -416,7 +416,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -416,7 +416,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
op = crossentropy_categorical_1hot op = crossentropy_categorical_1hot
fgraph = gof.FunctionGraph( fgraph = gof.FunctionGraph(
[x, one_of_n], [x, one_of_n],
[op(softmax(x), one_of_n)]) [op(softmax_op(x), one_of_n)])
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
...@@ -434,7 +434,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -434,7 +434,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
fgraph = gof.FunctionGraph( fgraph = gof.FunctionGraph(
[x, b, one_of_n], [x, b, one_of_n],
[op(softmax(x + b), one_of_n)]) [op(softmax_op(x + b), one_of_n)])
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
# print 'BEFORE' # print 'BEFORE'
...@@ -466,7 +466,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -466,7 +466,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
fgraph = gof.FunctionGraph( fgraph = gof.FunctionGraph(
[x, b, c, one_of_n], [x, b, c, one_of_n],
[op(softmax(T.add(x, b, c)), one_of_n)]) [op(softmax_op(T.add(x, b, c)), one_of_n)])
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
# print 'BEFORE' # print 'BEFORE'
...@@ -494,7 +494,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -494,7 +494,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
op = crossentropy_categorical_1hot op = crossentropy_categorical_1hot
fgraph = gof.FunctionGraph( fgraph = gof.FunctionGraph(
[x, b, one_of_n], [x, b, one_of_n],
[op(softmax(x + b), one_of_n)]) [op(softmax_op(x + b), one_of_n)])
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
# print 'BEFORE' # print 'BEFORE'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
...@@ -517,7 +517,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -517,7 +517,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
x = tensor.matrix('x') x = tensor.matrix('x')
one_of_n = tensor.lvector('one_of_n') one_of_n = tensor.lvector('one_of_n')
op = crossentropy_categorical_1hot op = crossentropy_categorical_1hot
xe = op(softmax(x), one_of_n) xe = op(softmax_op(x), one_of_n)
sum_xe = tensor.sum(xe) sum_xe = tensor.sum(xe)
g_x = tensor.grad(sum_xe, x) g_x = tensor.grad(sum_xe, x)
fgraph = gof.FunctionGraph( fgraph = gof.FunctionGraph(
...@@ -546,7 +546,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -546,7 +546,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
has_cx1hot = True has_cx1hot = True
if node.op == crossentropy_softmax_1hot_with_bias_dx: if node.op == crossentropy_softmax_1hot_with_bias_dx:
has_cx1hotdx = True has_cx1hotdx = True
if node.op == softmax: if node.op == softmax_op:
has_softmax = True has_softmax = True
if node.op == softmax_grad: if node.op == softmax_grad:
has_softmaxdx = True has_softmaxdx = True
...@@ -559,7 +559,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -559,7 +559,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
x = tensor.vector('x') x = tensor.vector('x')
one_of_n = tensor.lvector('one_of_n') one_of_n = tensor.lvector('one_of_n')
op = crossentropy_categorical_1hot op = crossentropy_categorical_1hot
xe = op(softmax(x), one_of_n) xe = op(softmax_op(x), one_of_n)
sum_xe = tensor.sum(xe) sum_xe = tensor.sum(xe)
g_x = tensor.grad(sum_xe, x) g_x = tensor.grad(sum_xe, x)
fgraph = gof.FunctionGraph( fgraph = gof.FunctionGraph(
...@@ -588,7 +588,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -588,7 +588,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
has_cx1hot = True has_cx1hot = True
if node.op == crossentropy_softmax_1hot_with_bias_dx: if node.op == crossentropy_softmax_1hot_with_bias_dx:
has_cx1hotdx = True has_cx1hotdx = True
if node.op == softmax: if node.op == softmax_op:
has_softmax = True has_softmax = True
if node.op == softmax_grad: if node.op == softmax_grad:
has_softmaxdx = True has_softmaxdx = True
...@@ -643,7 +643,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -643,7 +643,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
ops = [node.op for node in g.maker.fgraph.toposort()] ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) == 2 assert len(ops) == 2
assert crossentropy_softmax_1hot_with_bias_dx in ops assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax in ops assert softmax_op in ops
assert softmax_grad not in ops assert softmax_grad not in ops
g(x_val, y_val) g(x_val, y_val)
except Exception: except Exception:
...@@ -714,7 +714,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -714,7 +714,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
#there's an extra dimshuffle in there #there's an extra dimshuffle in there
# but I can't think of a good rule to get rid of it # but I can't think of a good rule to get rid of it
assert crossentropy_softmax_1hot_with_bias_dx in ops assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax in ops assert softmax_op in ops
assert softmax_grad not in ops assert softmax_grad not in ops
g(x_val, y_val) g(x_val, y_val)
except Exception: except Exception:
...@@ -796,7 +796,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -796,7 +796,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
ops = [node.op for node in g.maker.fgraph.toposort()] ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) == 3 assert len(ops) == 3
assert crossentropy_softmax_1hot_with_bias_dx in ops assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax in ops assert softmax_op in ops
assert softmax_grad not in ops assert softmax_grad not in ops
g(x_val, y_val) g(x_val, y_val)
except Exception: except Exception:
...@@ -841,7 +841,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -841,7 +841,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
ops = [node.op for node in g.maker.fgraph.toposort()] ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) == 4 assert len(ops) == 4
assert crossentropy_softmax_1hot_with_bias_dx in ops assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax in ops assert softmax_op in ops
assert softmax_grad not in ops assert softmax_grad not in ops
g(x_val, y_val) g(x_val, y_val)
except Exception: except Exception:
...@@ -1028,7 +1028,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -1028,7 +1028,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
theano.printing.debugprint(g) theano.printing.debugprint(g)
raise raise
def test_scale_cost(self): def test_crossentropy_softmax_1hot_with_bias_dxcale_cost(self):
# TODO: add the optimization in FAST_COMPILE? # TODO: add the optimization in FAST_COMPILE?
# In the mean time, run it as 'FAST_RUN' instead # In the mean time, run it as 'FAST_RUN' instead
mode = theano.compile.mode.get_default_mode() mode = theano.compile.mode.get_default_mode()
...@@ -1048,7 +1048,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -1048,7 +1048,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
for node in func.maker.fgraph.toposort(): for node in func.maker.fgraph.toposort():
if node.op == crossentropy_softmax_argmax_1hot_with_bias: if node.op == crossentropy_softmax_argmax_1hot_with_bias:
has_cx1hot = True has_cx1hot = True
if node.op == softmax: if node.op == softmax_op:
has_softmax = True has_softmax = True
assert has_cx1hot assert has_cx1hot
...@@ -1062,7 +1062,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -1062,7 +1062,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
for node in func.maker.fgraph.toposort(): for node in func.maker.fgraph.toposort():
if node.op == crossentropy_softmax_1hot_with_bias_dx: if node.op == crossentropy_softmax_1hot_with_bias_dx:
has_cx1hotdx = True has_cx1hotdx = True
if node.op == softmax: if node.op == softmax_op:
has_softmax = True has_softmax = True
if node.op == softmax_grad: if node.op == softmax_grad:
has_softmaxdx = True has_softmaxdx = True
...@@ -1129,49 +1129,49 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -1129,49 +1129,49 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
def test_argmax_pushdown(): def test_argmax_pushdown():
x = tensor.matrix() x = tensor.matrix()
for softmax in [softmax_graph, softmax_op]:
# test that the max_and_argmax is pushed down if the max is not used
out = tensor.max_and_argmax(
softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
axis=-1)[1]
fgraph = gof.FunctionGraph(
[x],
[out])
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
# test that the max_and_argmax is pushed down if the max is not used # print 'AFTER'
out = tensor.max_and_argmax( # for node in fgraph.toposort():
softmax(tensor.exp(tensor.tanh(sigmoid(x)))), # print node.op
axis=-1)[1] assert len(fgraph.toposort()) == 2 # an output_guard is second
fgraph = gof.FunctionGraph( assert fgraph.toposort()[0].op == tensor.basic._max_and_argmax
[x], assert str(fgraph.toposort()[1].op) == 'OutputGuard'
[out]) x = tensor.matrix()
theano.compile.mode.optdb.query( # test that the max_and_argmax is not pushed down if the max is used
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph) out = tensor.max_and_argmax(
softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
# print 'AFTER' axis=-1)[0]
# for node in fgraph.toposort(): fgraph = gof.FunctionGraph(
# print node.op [x],
assert len(fgraph.toposort()) == 2 # an output_guard is second [out])
assert fgraph.toposort()[0].op == tensor.basic._max_and_argmax
assert str(fgraph.toposort()[1].op) == 'OutputGuard'
x = tensor.matrix()
# test that the max_and_argmax is not pushed down if the max is used
out = tensor.max_and_argmax(
softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
axis=-1)[0]
fgraph = gof.FunctionGraph(
[x],
[out])
backup = config.warn.argmax_pushdown_bug backup = config.warn.argmax_pushdown_bug
config.warn.argmax_pushdown_bug = False config.warn.argmax_pushdown_bug = False
try: try:
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph) theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
finally: finally:
config.warn.argmax_pushdown_bug = backup config.warn.argmax_pushdown_bug = backup
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
assert len(fgraph.toposort()) == 4 # an output_guard is second assert len(fgraph.toposort()) == 4 # an output_guard is second
assert isinstance(fgraph.toposort()[0].op, tensor.Elemwise) assert isinstance(fgraph.toposort()[0].op, tensor.Elemwise)
assert isinstance(fgraph.toposort()[1].op, Softmax) assert isinstance(fgraph.toposort()[1].op, Softmax)
assert isinstance(fgraph.toposort()[2].op, tensor.CAReduce) assert isinstance(fgraph.toposort()[2].op, tensor.CAReduce)
assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum) assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[3].op) == 'OutputGuard' assert str(fgraph.toposort()[3].op) == 'OutputGuard'
def test_argmax_pushdown_bias(): def test_argmax_pushdown_bias():
...@@ -1295,7 +1295,7 @@ class Test_softmax_opt: ...@@ -1295,7 +1295,7 @@ class Test_softmax_opt:
# printing.debugprint(f) # printing.debugprint(f)
# print '===' # print '==='
assert len(f_ops) == 1 assert len(f_ops) == 1
assert softmax in f_ops assert softmax_op in f_ops
f(self.rng.rand(3, 4).astype(config.floatX)) f(self.rng.rand(3, 4).astype(config.floatX))
def test_basic_keepdims(self): def test_basic_keepdims(self):
...@@ -1309,7 +1309,7 @@ class Test_softmax_opt: ...@@ -1309,7 +1309,7 @@ class Test_softmax_opt:
# printing.debugprint(f) # printing.debugprint(f)
# print '===' # print '==='
assert len(f_ops) == 1 assert len(f_ops) == 1
assert softmax in f_ops assert softmax_op in f_ops
f(self.rng.rand(3, 4).astype(config.floatX)) f(self.rng.rand(3, 4).astype(config.floatX))
def test_grad(self): def test_grad(self):
...@@ -1331,7 +1331,7 @@ class Test_softmax_opt: ...@@ -1331,7 +1331,7 @@ class Test_softmax_opt:
raise SkipTest('Optimization not enabled for the moment') raise SkipTest('Optimization not enabled for the moment')
assert len(g_ops) == 2 assert len(g_ops) == 2
assert softmax in g_ops assert softmax_op in g_ops
assert softmax_grad in g_ops assert softmax_grad in g_ops
g(self.rng.rand(3, 4), self.rng.uniform(.5, 1, (3, 4))) g(self.rng.rand(3, 4), self.rng.uniform(.5, 1, (3, 4)))
...@@ -1377,12 +1377,33 @@ class Test_softmax_opt: ...@@ -1377,12 +1377,33 @@ class Test_softmax_opt:
# etc. # etc.
def test_softmax_graph():
    """Numerically verify the gradient propagated through ``softmax_graph``.

    A shared (3, 4) matrix is passed through ``softmax_graph``. The inner
    helper returns the gradient with respect to that matrix for a supplied
    gradient on the softmax output, and ``utt.verify_grad`` checks the helper.
    """
    random_state = numpy.random.RandomState(utt.fetch_seed())
    shared_x = theano.shared(random_state.normal(size=(3, 4)))

    def grad_wrt_x(output_grad):
        sm_out = softmax_graph(shared_x)
        # No scalar cost is given: the gradient on the softmax output is
        # injected through known_grads instead.
        return theano.grad(None, shared_x,
                           known_grads={sm_out: output_grad})

    point = random_state.rand(3, 4)
    utt.verify_grad(grad_wrt_x, [point])
def test_grad_softmax_grad():
    """Numerically verify the gradient of the gradient through ``softmax_op``.

    A shared (3, 4) matrix is passed through ``softmax_op``. The inner helper
    returns the gradient with respect to that matrix for a supplied gradient
    on the softmax output, and ``utt.verify_grad`` differentiates the helper.
    """
    random_state = numpy.random.RandomState(utt.fetch_seed())
    shared_x = theano.shared(random_state.normal(size=(3, 4)))

    def grad_wrt_x(output_grad):
        sm_out = softmax_op(shared_x)
        # No scalar cost is given: the gradient on the softmax output is
        # injected through known_grads instead.
        return theano.grad(None, shared_x,
                           known_grads={sm_out: output_grad})

    point = random_state.rand(3, 4)
    utt.verify_grad(grad_wrt_x, [point])
def test_stabilize_log_softmax(): def test_stabilize_log_softmax():
mode = theano.compile.mode.get_default_mode() mode = theano.compile.mode.get_default_mode()
mode = mode.including('local_log_softmax', 'specialize') mode = mode.including('local_log_softmax', 'specialize')
x = matrix() x = matrix()
y = theano.tensor.nnet.softmax(x) y = softmax(x)
z = theano.tensor.log(y) z = theano.tensor.log(y)
f = theano.function([x], z, mode=mode) f = theano.function([x], z, mode=mode)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论