Commit fd7e6d59 authored by Pascal Lamblin

merge

@@ -7,6 +7,19 @@ Graph optimization
In this section we will define a couple of optimizations on doubles.
.. todo::
    This tutorial goes way too far under the hood, for someone who just wants
    to add yet another pattern to the libraries in tensor.opt for example.
    We need another tutorial that covers the decorator syntax, and explains how
    to register your optimization right away. That's what you need to get
    going.
    Later, the rest is more useful for when that decorator syntax type thing
    doesn't work. (There are optimizations that don't fit that model).
Global and local optimizations
==============================
@@ -119,6 +132,11 @@ simplification described above:
simplify = Simplify()
.. todo::
    What is add_requirements? Why would we know to do this? Are there other
    requirements we might want to know about?
Here's how it works: first, in ``add_requirements``, we add the
``ReplaceValidate`` :ref:`envfeature` located in
:api:`theano.gof.toolbox`. This feature adds the ``replace_validate``
@@ -150,6 +168,7 @@ and :ref:`apply` to get a better understanding of the
pointer-following game you need to get ahold of the nodes of interest
for the simplification (``x``, ``y``, ``z``, ``a``, ``b``, etc.).
Test time:
>>> x = double('x')
@@ -238,6 +257,10 @@ The local version of the above code would be the following:
local_simplify = LocalSimplify()
.. todo::
    Fix up previous example... it's bad and incomplete.
The definition of transform is the inner loop of the global optimizer,
where the node is given as argument. If no changes are to be made,
``False`` must be returned. Else, a list of what to replace the node's
@@ -310,6 +333,9 @@ Theano defines some shortcuts to make LocalOptimizers:
means that everything we said previously about local optimizers
applies: they need to be wrapped in a Navigator, etc.
.. todo::
    What is a Navigator?
When an optimization can be naturally expressed using ``OpSub``, ``OpRemove``
or ``PatternSub``, it is highly recommended to use them.
@@ -319,6 +345,7 @@ use constraints, etc. - there's some decent doc at
:api:`theano.gof.opt.PatternSub` for those interested)
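To illustrate the idea behind ``PatternSub``, here is a toy rewriter over tuple expressions with ``'?name'`` holes. This is only the matching idea, not Theano's implementation, and the hole syntax is made up for this sketch:

```python
# Toy analogue of PatternSub: match a template with named '?holes'
# against a tuple expression, then instantiate a replacement template.
def match(template, expr, bindings):
    if isinstance(template, str) and template.startswith('?'):  # named hole
        if template in bindings:
            return bindings[template] == expr  # same hole must match same value
        bindings[template] = expr
        return True
    if isinstance(template, tuple):
        return (isinstance(expr, tuple) and len(template) == len(expr)
                and all(match(t, e, bindings) for t, e in zip(template, expr)))
    return template == expr  # literal (e.g. the operator name)

def instantiate(template, bindings):
    if isinstance(template, str) and template.startswith('?'):
        return bindings[template]
    if isinstance(template, tuple):
        return tuple(instantiate(t, bindings) for t in template)
    return template

def pattern_sub(in_pattern, out_pattern, expr):
    bindings = {}
    if match(in_pattern, expr, bindings):
        return instantiate(out_pattern, bindings)
    return expr  # no match: leave the expression alone

# ('mul', x, x) -> ('sqr', x), in the spirit of PatternSub((mul, 'x', 'x'), (sqr, 'x'))
print(pattern_sub(('mul', '?x', '?x'), ('sqr', '?x'), ('mul', 5, 5)))  # -> ('sqr', 5)
```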
.. _optdb:

The optimization database (optdb)
......
@@ -1781,36 +1781,44 @@ def div_proxy(x, y):
    return true_div(x, y)
@_scal_elemwise
def add(a, *other_terms):
    """elementwise addition"""
    # see decorator for function body

@_scal_elemwise
def sub(a, b):
    """elementwise subtraction"""
    # see decorator for function body

@_scal_elemwise
def mul(a, *other_terms):
    """elementwise multiplication"""
    # see decorator for function body

@_scal_elemwise
def true_div(a, b):
    """elementwise [true] division (inverse of multiplication)"""
    # see decorator for function body

@_scal_elemwise
def int_div(a, b):
    """elementwise integer-division"""
    # see decorator for function body

@_scal_elemwise
def mod(a, b):
    """elementwise modulo"""
    # see decorator for function body

@_scal_elemwise
def pow(a, b):
    """elementwise power"""
    # see decorator for function body

@_scal_elemwise
def clip(x, min, max):
    """clip x to be between min and max"""
    # see decorator for function body
pprint.assign(add, printing.OperatorPrinter('+', -2, 'either'))
pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either'))
@@ -3007,6 +3015,8 @@ class AdvancedSubtensor(Op):
        #TODO: see what's the best solution
        self.args = args #?
        #FIXME: do not store variables in the class instance

        #FIXME
        #if len(args) != 2:
        #    print >>sys.stderr, 'WARNING: Advanced indexing with %i arguments not supported yet' % len(args)
@@ -3018,6 +3028,11 @@ class AdvancedSubtensor(Op):
        if x.ndim == 2 and len(inputs) == 2:
            ind1 = as_tensor_variable(inputs[0])
            ind2 = as_tensor_variable(inputs[1])
            if not (ind1.type.dtype.startswith('int') or ind1.type.dtype.startswith('uint')):
                raise TypeError()
            if not (ind2.type.dtype.startswith('int') or ind2.type.dtype.startswith('uint')):
                raise TypeError()
            if ind1.ndim == 1 and ind2.ndim == 1:
                return gof.Apply(self,
                        (x,) + inputs,
@@ -3029,7 +3044,11 @@ class AdvancedSubtensor(Op):
                % ','.join(str(input) for input in inputs))
    def perform(self, node, inputs, (out,)):
        # TODO: in general, we need to re-pack the inputs into a valid index, just like
        # subtensor
        out[0] = inputs[0].__getitem__(inputs[1:])
        #return
        #raise NotImplementedError()
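The ``perform`` body above leans on numpy's advanced indexing: indexing a matrix with two integer arrays picks one element per (row, column) pair. A minimal demonstration:

```python
import numpy

# numpy advanced indexing with a tuple of index arrays selects one
# element per (row, column) pair -- what perform delegates to.
x = numpy.arange(12).reshape(3, 4)
rows = numpy.array([0, 1, 2])
cols = numpy.array([3, 0, 1])
out = x.__getitem__((rows, cols))  # same as x[rows, cols]
print(out)  # -> [3 4 9]
```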
    def grad(self, inputs, (gz,)):
        x = inputs[0]
@@ -3061,9 +3080,14 @@ class AdvancedIncSubtensor(Op):
                % ','.join(str(input) for input in inputs))
    def perform(self, node, inputs, (out,)):
        # TODO: same thing as in AdvancedSubtensor's perform
        out[0] = inputs[0].copy()
        out[0][inputs[2:]] += inputs[1]

    #def grad?
    # grad on x is grad on output
    # grad on y is grad_output[idx_list]
    # grad on rest is None
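The increment in ``perform`` above is numpy fancy-index ``+=``. One caveat worth noting (not addressed by this commit): with duplicate indices, buffered ``+=`` applies only once per position, while ``numpy.add.at`` accumulates every occurrence:

```python
import numpy

# Fancy-index += vs. numpy.add.at with a repeated index.
x = numpy.zeros(4)
idx = numpy.array([1, 1, 3])
y = numpy.ones(3)

a = x.copy()
a[idx] += y              # duplicate index 1 is applied only once
numpy.add.at(x, idx, y)  # unbuffered: index 1 is incremented twice

print(a)  # -> [0. 1. 0. 1.]
print(x)  # -> [0. 2. 0. 1.]
```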
......
@@ -395,6 +395,8 @@ softmax = Softmax()
@opt.register_specialize
@gof.local_optimizer([softmax])
def local_softmax_with_bias(node):
    """Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias)
    """
    if node.op == softmax:
        x, = node.inputs
        if x.owner and x.owner.op == tensor.add:
@@ -422,15 +424,12 @@ def local_softmax_with_bias(node):
                vector_sum = tensor.add(*vectors)
            else:
                vector_sum = vectors[0]
#backport
#vector_sum = tensor.add(*vectors) if len(vectors)>1 else vectors[0]
            if len(non_vectors) > 1:
                non_vector_sum = tensor.add(*non_vectors)
            else:
                non_vector_sum = non_vectors[0]
#non_vector_sum = tensor.add(*non_vectors) if len(non_vectors)>1 else non_vectors[0]
            try:
                sm_bias = softmax_with_bias(non_vector_sum, vector_sum)
            except:
@@ -697,7 +696,9 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
        }
        if (%(dnll)s->dimensions[0] != %(sm)s->dimensions[0])
        {
            PyErr_Format(PyExc_ValueError, "dnll.shape[0] (%%d) != sm.shape[0] (%%d)",
                %(dnll)s->dimensions[0], %(sm)s->dimensions[0]);
            //PyErr_SetString(PyExc_ValueError, "dnll.shape[0] != sm.shape[0]");
            %(fail)s;
        }
        if (%(dnll)s->dimensions[0] != %(y_idx)s->dimensions[0])
@@ -849,13 +850,15 @@ def crossentropy_to_crossentropy_with_softmax(env):
            x, = sm.owner.inputs
            new_nll, new_sm, new_am = crossentropy_softmax_argmax_1hot_with_bias(x,
                    tensor.zeros_like(x[0]), one_of_n)
            env.replace_all_validate([(nll, new_nll), (sm, new_sm)],
                    reason="crossentropy_to_crossentropy_with_softmax")
            return True
        if sm.owner and sm.owner.op == softmax_with_bias:
            x, b = sm.owner.inputs
            new_nll, new_sm, new_am = crossentropy_softmax_argmax_1hot_with_bias(x, b,
                    one_of_n)
            env.replace_all_validate([(nll, new_nll), (sm, new_sm)],
                    reason="crossentropy_to_crossentropy_with_softmax")
            return True
    return False
@@ -892,6 +895,239 @@ def local_argmax_pushdown(node):
        return tensor._max_and_argmax(pre_x + tensor.DimShuffle(pre_bias.broadcastable,
                ('x', 0))(pre_bias), axis)
# Utility function used by the next two optimizations
def _check_rows_is_arange_len_labels(rows, labels):
    '''Check that 'rows' is the same node as T.arange(labels.shape[0])'''
    if rows.owner and isinstance(rows.owner.op, tensor.ARange):
        start, stop, step = rows.owner.inputs
        if getattr(start, 'data', None) != 0:  # constants will have data
            return False
        if getattr(step, 'data', None) != 1:  # constant step will have data
            return False
        if stop.owner and isinstance(stop.owner.op, tensor.Subtensor):
            shape_subtensor = stop.owner
            if shape_subtensor.op.idx_list == [0]:
                shape_var, = shape_subtensor.inputs
                if shape_var.owner and shape_var.owner.op == tensor._shape:
                    return shape_var.owner.inputs[0] is labels
@opt.register_specialize
@gof.local_optimizer([])
def local_advanced_indexing_crossentropy_onehot(node):
    log = None
    sm = None

    # First case: log(softmax(x))[rows, labels]
    if isinstance(node.op, tensor.AdvancedSubtensor):
        try:
            log, rows, labels = node.inputs
        except:
            pass
        if log and log.owner and log.owner.op == tensor.log:
            sm = log.owner.inputs[0]

    # Second case: log(softmax(x)[rows, labels])
    if node.op == tensor.log:
        pre_log = node.inputs[0].owner
        if pre_log and isinstance(pre_log.op, tensor.AdvancedSubtensor):
            try:
                sm, rows, labels = pre_log.inputs
            except:
                pass

    if sm is not None and sm.owner and sm.owner.op in (softmax, softmax_with_bias):
        sm_w_bias = local_softmax_with_bias.transform(sm.owner)
        if sm_w_bias:
            assert sm_w_bias[0].owner.op == softmax_with_bias
            x_var, b_var = sm_w_bias[0].owner.inputs
        else:
            x_var = sm.owner.inputs[0]
            b_var = tensor.zeros_like(x_var[0])

        # Check that rows == arange(labels.shape[0])
        if _check_rows_is_arange_len_labels(rows, labels):
            if labels.ndim == 1 and x_var.ndim == 2:
                return [-crossentropy_softmax_argmax_1hot_with_bias(x_var, b_var, labels)[0]]
@opt.register_specialize
@gof.local_optimizer([softmax_grad])
def local_advanced_indexing_crossentropy_onehot_grad(node):
    if not (node.op == softmax_grad):
        return

    sm = None
    try:
        out_grad, sm = node.inputs
    except:
        return

    if sm is not None and sm.owner and sm.owner.op == softmax:
        x_var = sm.owner.inputs[0]
    else:
        return
    # Two cases are supported:
    # 1. AdvancedIncSubtensor(
    #        zeros_like(softmax(x)),
    #        -1. / AdvancedSubtensor(softmax(x), arange(y.shape[0]), y),
    #        arange(y.shape[0]),
    #        y)
    #    which arises from the gradient of log(softmax(x)[arange(y.shape[0]), y])
    #
    # 2. AdvancedIncSubtensor(
    #        zeros_like(log(softmax(x))),
    #        -1. like (AdvancedSubtensor(log(softmax(x)), arange(y.shape[0]), y)),
    #        arange(y.shape[0]),
    #        y)
    #    / softmax(x)
    #    which arises from the gradient of log(softmax(x))[arange(y.shape[0]), y]
    #
    # In some cases, in case 2., instead of "-1. like (AdvancedSubtensor...)",
    # we can have "-1. like ([-1] * AdvancedSubtensor...)". This case will be
    # recognized too, but other variants, even with the same shape, might not
    # (yet).
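Both cases described in the comment above are graphs for the gradient of ``-log(softmax(x))[arange(y.shape[0]), y]``, whose closed form is ``softmax(x)`` minus the one-hot encoding of ``y`` per row (the quantity ``crossentropy_softmax_1hot_with_bias_dx`` computes). A quick numpy finite-difference check of that identity:

```python
import numpy

# Check: d/dx of sum_i -log(softmax(x)_i[y_i]) == softmax(x) - onehot(y).
rng = numpy.random.RandomState(0)
x = rng.randn(3, 5)
y = numpy.array([2, 4, 1])

def loss(x):
    sm = numpy.exp(x) / numpy.exp(x).sum(axis=1, keepdims=True)
    return -numpy.log(sm[numpy.arange(len(y)), y]).sum()

sm = numpy.exp(x) / numpy.exp(x).sum(axis=1, keepdims=True)
analytic = sm.copy()
analytic[numpy.arange(len(y)), y] -= 1.0  # subtract the one-hot of y

# central finite differences, element by element
numeric = numpy.zeros_like(x)
eps = 1e-6
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        xp = x.copy(); xp[i, j] += eps
        xm = x.copy(); xm[i, j] -= eps
        numeric[i, j] = (loss(xp) - loss(xm)) / (2 * eps)

print(numpy.allclose(analytic, numeric, atol=1e-5))  # -> True
```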
    # First case.
    # After the check for AdvancedIncSubtensor, if anything does not fit with
    # the formula above, there's no way to fit it with the second case,
    # so we return immediately.
    if out_grad.owner and isinstance(out_grad.owner.op, tensor.AdvancedIncSubtensor):
        try:
            z, incr, rows, labels = out_grad.owner.inputs
        except:
            return

        # Check that z == zeros_like(softmax(x))
        if z.owner and z.owner.op == tensor.fill:
            model, value = z.owner.inputs
            if not (model is sm and numpy.all(value.data == 0)):
                return
            #else: OK
        else:
            return

        # Check that incr has the form -1./sm[arange(len(y)), y]
        if incr.owner and incr.owner.op == tensor.true_div:
            num, denom = incr.owner.inputs
            if not numpy.all(num.data == -1):
                return
            #else: OK
            if denom.owner and isinstance(denom.owner.op, tensor.AdvancedSubtensor):
                try:
                    maybe_sm, maybe_rows, maybe_labels = denom.owner.inputs
                except:
                    return
                if not (maybe_sm is sm and maybe_rows is rows and maybe_labels is labels):
                    return
                #else: OK
            else:
                return
        else:
            return

        # Check that rows is arange(labels.shape[0])
        if not _check_rows_is_arange_len_labels(rows, labels):
            return
        # else, arguments of AdvancedIncSubtensor are OK,
        # it was really case 1.
    # Second case
    elif out_grad.owner and out_grad.owner.op == tensor.true_div:
        try:
            num, denom = out_grad.owner.inputs
        except:
            return

        # Check the numerator (AdvancedIncSubtensor)
        if num.owner and isinstance(num.owner.op, tensor.AdvancedIncSubtensor):
            try:
                z, incr, rows, labels = num.owner.inputs
            except:
                return

            # Check z is zeros_like(log(sm))
            if z.owner and z.owner.op == tensor.fill:
                model, value = z.owner.inputs
                if model.owner and model.owner.op == tensor.log:
                    if sm is model.owner.inputs[0]:
                        log_sm = model
                    else:
                        return
                    if not numpy.all(value.data == 0):
                        return
                    #else: OK
                else:
                    return
            else:
                return

            # Check incr is (-1.) like log(softmax(x))[arange(len(y)), y]
            if incr.owner and incr.owner.op == tensor.fill:
                model, value = incr.owner.inputs
                adv_subtensor = None
                if model.owner and isinstance(model.owner.op, tensor.AdvancedSubtensor):
                    adv_subtensor = model
                else:
                    if model.owner and isinstance(model.owner.op, tensor.Elemwise):
                        for input in model.owner.inputs:
                            if input.owner and isinstance(input.owner.op, tensor.AdvancedSubtensor):
                                adv_subtensor = input
                                break
                        #TODO: try them all, not just the first one
                    else:
                        return
                if adv_subtensor is not None:
                    try:
                        maybe_log_sm, maybe_rows, maybe_labels = adv_subtensor.owner.inputs
                    except:
                        return
                    if not (maybe_log_sm is log_sm and maybe_rows is rows and maybe_labels is labels):
                        return
                    #else: OK
                if not numpy.all(value.data == -1):
                    return
            else:
                return

            # Check that rows is arange(labels.shape[0])
            if not _check_rows_is_arange_len_labels(rows, labels):
                return
            # else, arguments of AdvancedIncSubtensor are OK

            # Check the denominator (sm)
            if denom is not sm:
                return
            # else, numerator and denominator are OK,
            # it was really case 2.
    else:
        return
    # Dimension check before substitution
    if labels.ndim == 1 and x_var.ndim == 2:
        print 'YAY!'
        return [crossentropy_softmax_1hot_with_bias_dx(tensor.ones_like(sm[:,0]), sm, labels)]
    else:
        return
def binary_crossentropy(output, target):
......
@@ -14,7 +14,8 @@ from elemwise import Elemwise, DimShuffle
from theano import scalar
import basic as T
import inplace as I
import numpy
import numpy as N #guys... please don't do this in the library :(
import operator
import itertools
import sys, os
@@ -62,7 +63,6 @@ def get_constant_value(v):
        return get_constant_value(v.owner.inputs[0])
    raise TypeError(v)
@gof.optimizer
def insert_inplace_optimizer(env):
    """
@@ -108,10 +108,12 @@ compile.optdb.register('inplace_opt', insert_inplace_optimizer, 75, 'fast_run',
def register_canonicalize(lopt, *tags, **kwargs):
    name = (kwargs and kwargs.pop('name')) or lopt.__name__
    compile.optdb['canonicalize'].register(name, lopt, 'fast_run', *tags)
    return lopt

def register_specialize(lopt, *tags, **kwargs):
    name = (kwargs and kwargs.pop('name')) or lopt.__name__
    compile.optdb['specialize'].register(name, lopt, 'fast_run', *tags)
    return lopt
######################
# DimShuffle lifters #
@@ -876,10 +878,39 @@ register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer')
def local_neg_to_mul(node):
    if node.op == T.neg:
        return [T.mul(-1, node.inputs[0])]
    else:
        return False
register_canonicalize(local_neg_to_mul)
@register_specialize
@gof.local_optimizer([])
def local_sum_mul_by_scalar(node):
    """sum(scalar * smth) -> scalar * sum(smth)
    """
    # TODO: if the thing inside the Sum is a division,
    # we should get at the numerator....
    if isinstance(node.op, T.Sum):
        thing_summed, = node.inputs
        if thing_summed.owner and thing_summed.owner.op == T.mul:
            terms = thing_summed.owner.inputs
            scalars = [t.dimshuffle() for t in terms if numpy.all(t.type.broadcastable)]
            non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
            if scalars:
                if len(scalars) > 1:
                    if len(non_scalars) > 1:
                        return [T.mul(T.mul(*scalars), node.op(T.mul(*non_scalars)))]
                    elif len(non_scalars) == 1:
                        return [T.mul(T.mul(*scalars), node.op(non_scalars[0]))]
                    else:
                        return [T.mul(*scalars)]
                else:
                    if len(non_scalars) > 1:
                        return [T.mul(scalars[0], node.op(T.mul(*non_scalars)))]
                    elif len(non_scalars) == 1:
                        return [T.mul(scalars[0], node.op(non_scalars[0]))]
                    else:
                        return [scalars[0]]
        if thing_summed.owner and thing_summed.owner.op == T.neg:
            return [T.neg(node.op(thing_summed.owner.inputs[0]))]
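The rewrite above relies on linearity of summation, ``sum(c * a) == c * sum(a)`` for a scalar ``c`` (up to floating-point rounding). A quick numpy check of the identity:

```python
import numpy

# Linearity of summation: factoring a scalar out of a sum.
rng = numpy.random.RandomState(42)
a = rng.randn(100)
c = 2.5
lhs = (c * a).sum()  # sum(scalar * smth)
rhs = c * a.sum()    # scalar * sum(smth)
print(numpy.allclose(lhs, rhs))  # -> True
```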
@gof.local_optimizer([T.mul])
def local_mul_to_neg(node):
    if node.op == T.mul and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == -1.0):
@@ -888,6 +919,16 @@ def local_mul_to_neg(node):
    return False
register_specialize(local_mul_to_neg)
@register_specialize
@gof.local_optimizer([T.neg])
def local_neg_neg(node):
    # other specializations shouldn't put this in,
    # but sometimes they do
    if node.op == T.neg:
        if node.inputs[0].owner and node.inputs[0].owner.op == T.neg:
            return [node.inputs[0].owner.inputs[0]]
@gof.local_optimizer([T.mul])
def local_mul_zero(node):
    """As part of canonicalization, we replace multiplication by zero with zero.
......
@@ -317,6 +317,44 @@ def test_asymptotic_32():
    assert gxval[0,1] == 0.25
def test_get_rid_of_advanced_indexing_version_of_xent():
    rng = numpy.random.RandomState(utt.fetch_seed())
    x_val = rng.randn(3,5)
    y_val = numpy.asarray([2,4,1])
    x = T.dmatrix('x')
    y = T.lvector('y')

    expressions_to_test = [
        T.sum(-T.log(softmax(x)[T.arange(y.shape[0]), y])),
        -T.sum(T.log(softmax(x)[T.arange(y.shape[0]), y])),
        -T.sum(T.log(softmax(x))[T.arange(y.shape[0]), y]),
        T.sum(-T.log(softmax(x))[T.arange(y.shape[0]), y])]

    def assert_optimizer_worked(expr):
        f = theano.function([x,y], expr)
        for i, node in enumerate(f.maker.env.toposort()):
            print i, node
        f(x_val, y_val)
        assert len(f.maker.env.toposort()) == 4

    for expr in expressions_to_test:
        assert_optimizer_worked(expr)

    ## Gradient wrt x
    for expr in expressions_to_test:
        grad_x = T.grad(expr, x)
        g = theano.function([x, y], grad_x)
        for i, node in enumerate(g.maker.env.toposort()):
            print i, node
        g(x_val, y_val)
        assert len(g.maker.env.toposort()) == 4

    #TODO: Case with bias
    # hint - call local_softmax_with_bias from within the other optimization
    # hint - call the argmax push-down optimization first too
if __name__ == '__main__':
    unittest.main()