提交 8edbfbdf authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

Combine similarly named T_prepend classes; touch up text layout

上级 a5b241e2
...@@ -233,6 +233,7 @@ class T_CrossentropySoftmax1HotWithBiasDx(utt.InferShapeTester): ...@@ -233,6 +233,7 @@ class T_CrossentropySoftmax1HotWithBiasDx(utt.InferShapeTester):
class T_CrossentropySoftmaxArgmax1HotWithBias(utt.InferShapeTester): class T_CrossentropySoftmaxArgmax1HotWithBias(utt.InferShapeTester):
def setUp(self): def setUp(self):
super(T_CrossentropySoftmaxArgmax1HotWithBias, self).setUp() super(T_CrossentropySoftmaxArgmax1HotWithBias, self).setUp()
self.op = theano.tensor.nnet.crossentropy_softmax_argmax_1hot_with_bias self.op = theano.tensor.nnet.crossentropy_softmax_argmax_1hot_with_bias
...@@ -281,21 +282,8 @@ class T_prepend(utt.InferShapeTester): ...@@ -281,21 +282,8 @@ class T_prepend(utt.InferShapeTester):
self.assertTrue(my.shape == (3, 6), my.shape) self.assertTrue(my.shape == (3, 6), my.shape)
self.assertTrue(numpy.all(my[:, 0] == 4.0)) self.assertTrue(numpy.all(my[:, 0] == 4.0))
def test_infer_shape(self): def test1(self):
admat = dmatrix() "basic functionality"
rng = numpy.random.RandomState(utt.fetch_seed())
admat_val = rng.rand(3, 5)
adscal_val = rng.rand()
self._compile_and_check([admat],
[Prepend_scalar_constant_to_each_row(adscal_val)(admat)],
[admat_val],
Prepend_scalar_constant_to_each_row)
class T_prepend(utt.InferShapeTester):
def test0(self):
"""basic functionality"""
x = tensor.matrix('x') x = tensor.matrix('x')
y = Prepend_scalar_to_each_row()(5., x) y = Prepend_scalar_to_each_row()(5., x)
f = theano.function([x], y) f = theano.function([x], y)
...@@ -310,6 +298,11 @@ class T_prepend(utt.InferShapeTester): ...@@ -310,6 +298,11 @@ class T_prepend(utt.InferShapeTester):
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
admat_val = rng.rand(3, 5) admat_val = rng.rand(3, 5)
adscal_val = rng.rand() adscal_val = rng.rand()
self._compile_and_check([admat],
[Prepend_scalar_constant_to_each_row(adscal_val)(admat)],
[admat_val],
Prepend_scalar_constant_to_each_row)
self._compile_and_check([adscal, admat], self._compile_and_check([adscal, admat],
[Prepend_scalar_to_each_row()(adscal, admat)], [Prepend_scalar_to_each_row()(adscal, admat)],
[adscal_val, admat_val], [adscal_val, admat_val],
...@@ -317,6 +310,7 @@ class T_prepend(utt.InferShapeTester): ...@@ -317,6 +310,7 @@ class T_prepend(utt.InferShapeTester):
class T_CrossentropyCategorical1HotGrad(utt.InferShapeTester): class T_CrossentropyCategorical1HotGrad(utt.InferShapeTester):
def test_infer_shape(self): def test_infer_shape(self):
advec = dvector() advec = dvector()
admat = dmatrix() admat = dmatrix()
...@@ -773,7 +767,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -773,7 +767,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
raise raise
g = theano.function([x,y], T.grad(expr, x), mode=mode) g = theano.function([x, y], T.grad(expr, x), mode=mode)
print_graph(g) print_graph(g)
try: try:
ops = [node.op for node in g.maker.fgraph.toposort()] ops = [node.op for node in g.maker.fgraph.toposort()]
...@@ -999,6 +993,7 @@ def test_argmax_pushdown(): ...@@ -999,6 +993,7 @@ def test_argmax_pushdown():
assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum) assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[3].op) == 'OutputGuard' assert str(fgraph.toposort()[3].op) == 'OutputGuard'
def test_argmax_pushdown_bias(): def test_argmax_pushdown_bias():
x = tensor.dmatrix() x = tensor.dmatrix()
b = tensor.dvector() b = tensor.dvector()
...@@ -1073,7 +1068,6 @@ def test_asymptotic_32(): ...@@ -1073,7 +1068,6 @@ def test_asymptotic_32():
xval = numpy.zeros((5, 5), dtype=dtype) xval = numpy.zeros((5, 5), dtype=dtype)
x2val = numpy.zeros(5, dtype=xval.dtype) x2val = numpy.zeros(5, dtype=xval.dtype)
for i in xrange(100): for i in xrange(100):
cval, gxval = f(xval, numpy.arange(5), x2val) cval, gxval = f(xval, numpy.arange(5), x2val)
xval -= 100.3 * gxval xval -= 100.3 * gxval
#print cval, gxval #print cval, gxval
...@@ -1099,8 +1093,8 @@ class Test_softmax_opt: ...@@ -1099,8 +1093,8 @@ class Test_softmax_opt:
# divided by row sums are replaced by softmax expressions. # divided by row sums are replaced by softmax expressions.
# #
# Softmax_grad isn't that interesting as an Op, but it has the signature # Softmax_grad isn't that interesting as an Op, but it has the signature
# we look for when trying to insert CrossEntropySoftmax... grad. So for now, # we look for when trying to insert CrossEntropySoftmax... grad. So, for
# we add softmax_grad to graphs. In the future, we may modify the # now, we add softmax_grad to graphs. In the future, we may modify the
# CrossEntropySoftmax...grad to look for the more basic pattern. # CrossEntropySoftmax...grad to look for the more basic pattern.
# #
...@@ -1133,7 +1127,7 @@ class Test_softmax_opt: ...@@ -1133,7 +1127,7 @@ class Test_softmax_opt:
backup = config.warn.sum_div_dimshuffle_bug backup = config.warn.sum_div_dimshuffle_bug
config.warn.sum_div_dimshuffle_bug = False config.warn.sum_div_dimshuffle_bug = False
try: try:
g = theano.function([c, w],T.grad((p_y*w).sum(), c)) g = theano.function([c, w], T.grad((p_y * w).sum(), c))
finally: finally:
config.warn.sum_div_dimshuffle_bug = backup config.warn.sum_div_dimshuffle_bug = backup
g_ops = [n.op for n in g.maker.fgraph.toposort()] g_ops = [n.op for n in g.maker.fgraph.toposort()]
...@@ -1145,7 +1139,7 @@ class Test_softmax_opt: ...@@ -1145,7 +1139,7 @@ class Test_softmax_opt:
assert len(g_ops) == 2 assert len(g_ops) == 2
assert softmax in g_ops assert softmax in g_ops
assert softmax_grad in g_ops assert softmax_grad in g_ops
g(self.rng.rand(3, 4), self.rng.uniform(.5, 1, (3,4))) g(self.rng.rand(3, 4), self.rng.uniform(.5, 1, (3, 4)))
def test_transpose_basic(self): def test_transpose_basic(self):
# this should be a transposed softmax # this should be a transposed softmax
...@@ -1185,7 +1179,10 @@ class Test_softmax_opt: ...@@ -1185,7 +1179,10 @@ class Test_softmax_opt:
#printing.debugprint(g) #printing.debugprint(g)
raise SkipTest('Optimization not enabled for the moment') raise SkipTest('Optimization not enabled for the moment')
# REPEAT 3 CASES in presence of log(softmax) with the advanced indexing etc. # REPEAT 3 CASES in presence of log(softmax) with the advanced indexing
# etc.
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论