提交 2e7790ac authored 作者: James Bergstra's avatar James Bergstra

fixed syntax errs in log1p

上级 c82a3f87
......@@ -193,6 +193,8 @@ class Mode(object):
def including(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, self.provided_optimizer)
#N.B. opt might be a Query instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows???
return self.__class__(linker=link, optimizer=opt.including(*tags))
def excluding(self, *tags):
......
......@@ -905,36 +905,37 @@ class test_fusion(unittest.TestCase):
#g.owner.inputs[0] is out... make owner a weakref?
def test_log1p():
m = theano.compile.FAST_RUN
# check some basic cases
x = dvector()
f = function([x], T.log(1+(x)), mode='FAST_RUN')
f = function([x], T.log(1+(x)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
f = (function([x], T.log(1+(-x))), mode='FAST_RUN')
f = function([x], T.log(1+(-x)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.neg, inplace.log1p_inplace]
f = (function([x], -T.log(1+(-x))), mode='FAST_RUN')
f = function([x], -T.log(1+(-x)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.neg, inplace.log1p_inplace, inplace.neg_inplace]
# check trickier cases (and use different dtype)
y = fmatrix()
f = (function([x,y], T.log(fill(y,1)+(x))), mode='FAST_RUN')
f = function([x,y], T.log(fill(y,1)+(x)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.DimShuffle([False], ['x', 0], True), T.log1p, T.fill]
f = (function([x,y], T.log(0+(x) + fill(y,1.0) )), mode='FAST_RUN')
f = function([x,y], T.log(0+(x) + fill(y,1.0)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.DimShuffle([False], ['x', 0], True), T.log1p, T.fill]
f = (function([x,y], T.log(2+(x) - fill(y,1.0) )), mode='FAST_RUN')
f = function([x,y], T.log(2+(x) - fill(y,1.0)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.DimShuffle([False], ['x', 0], True), T.log1p, T.fill]
f([1e-7, 10], [[0, 0], [0, 0]]) #debugmode will verify values
# should work for complex
z = zmatrix()
f = function([z], T.log(1+(z)), mode='FAST_RUN')
f = function([z], T.log(1+(z)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
# should work for int
z = imatrix()
f = function([z], T.log(1+(z)), mode='FAST_RUN')
f = function([z], T.log(1+(z)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论