提交 89718faa authored 作者: Frederic's avatar Frederic

Fix gh-788. Now after the other opt change, fixing this got very easy.

上级 81cfecf3
......@@ -1248,16 +1248,8 @@ class CrossentropyCategorical1Hot(gof.Op):
y[i] = -numpy.log(coding[i, one_of_n[i]])
y_out[0] = y
# Enabling this infer_shape method make 2 tests fail:
# theano/tensor/nnet/tests/test_nnet.py:T_CrossentropyCategorical1Hot.
# {test_softmax_grad_optimizations,test_softmax_grad_optimizations_vector}
# This is caused by the local_fill_to_alloc that call broadcast_like
# that look into the shape feature and return a Rebroadcast instead of an alloc.
# I disable this infer_shape until we fix the optimizations or determine that
# this is not needed anymore and we update the tests.
# see issue gh-788
# def infer_shape(self, node, in_shapes):
# return [(in_shapes[0][0],)]
def infer_shape(self, node, in_shapes):
return [(in_shapes[0][0],)]
def grad(self, inp, grads):
coding, one_of_n = inp
......
......@@ -380,8 +380,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
tensor.verify_grad(oplike, [x_val], rng=numpy.random)
# see issue gh-788
def est_infer_shape(self):
def test_infer_shape(self):
admat = matrix()
alvec = lvector()
rng = numpy.random.RandomState(utt.fetch_seed())
......@@ -535,8 +534,6 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# for node in fgraph.toposort():
# print node.op, node.inputs
# the function has 9 ops because the dimshuffle and lemwise{second}
# aren't getting cleaned up as well as we'd like.
has_cx1hot = False
has_cx1hotdx = False
has_softmax = False
......@@ -550,9 +547,9 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
has_softmax = True
if node.op == softmax_grad:
has_softmaxdx = True
assert has_cx1hot
assert not has_cx1hot
assert has_cx1hotdx
assert not has_softmax
assert has_softmax
assert not has_softmaxdx
def test_softmax_grad_optimizations_vector(self):
......@@ -577,8 +574,6 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# for node in fgraph.toposort():
# print node.op, node.inputs
# the function has 9 ops because the dimshuffle and elemwise{second}
# aren't getting cleaned up as well as we'd like.
has_cx1hot = False
has_cx1hotdx = False
has_softmax = False
......@@ -592,9 +587,9 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
has_softmax = True
if node.op == softmax_grad:
has_softmaxdx = True
assert has_cx1hot
assert not has_cx1hot
assert has_cx1hotdx
assert not has_softmax
assert has_softmax
assert not has_softmaxdx
def test_get_rid_of_advanced_indexing_version_of_xent(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论