提交 aa47b0ab authored 作者: James Bergstra's avatar James Bergstra

fix #162

上级 02a2b749
...@@ -1058,6 +1058,16 @@ class T_add(unittest.TestCase): ...@@ -1058,6 +1058,16 @@ class T_add(unittest.TestCase):
def test_grad_col(self): def test_grad_col(self):
verify_grad(self, add, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)]) verify_grad(self, add, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)])
class T_exp(unittest.TestCase):
def test_grad_0(self):
verify_grad(self, exp, [
numpy.asarray([[ 1.5089518 , 1.48439076, -4.7820262 ],
[ 2.04832468, 0.50791564, -1.58892269]])])
def test_grad_1(self):
verify_grad(self, exp_inplace, [
numpy.asarray([[ 1.5089518 , 1.48439076, -4.7820262 ],
[ 2.04832468, 0.50791564, -1.58892269]])])
# class T_abs(unittest.TestCase): # class T_abs(unittest.TestCase):
# def test_impl(self): # def test_impl(self):
...@@ -1711,7 +1721,19 @@ class _test_grad(unittest.TestCase): ...@@ -1711,7 +1721,19 @@ class _test_grad(unittest.TestCase):
self.failUnless(isinstance(g2, TensorConstant)) self.failUnless(isinstance(g2, TensorConstant))
self.failUnless(g2.data == 0) self.failUnless(g2.data == 0)
class T_op_cache(unittest.TestCase):
def test0(self):
"""trigger bug in ticket #162"""
lr = constant(0.011)
v = matrix()
v.name = 'v'
gv = fill(v/v, 1.0)/v - (fill(v/v, 1.0) * v) / (v*v)
fn_py = function([v], [gv], linker = 'py')
fn_c_or_py = function([v], [gv], linker = 'c|py')
a = numpy.random.rand(5,2)
self.failUnless(numpy.all(fn_py(a) == fn_c_or_py(a)))
if __name__ == '__main__': if __name__ == '__main__':
if 1: if 1:
...@@ -1723,3 +1745,4 @@ if __name__ == '__main__': ...@@ -1723,3 +1745,4 @@ if __name__ == '__main__':
suite = suite.loadTestsFromTestCase(testcase) suite = suite.loadTestsFromTestCase(testcase)
unittest.TextTestRunner(verbosity=2).run(suite) unittest.TextTestRunner(verbosity=2).run(suite)
...@@ -816,12 +816,13 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -816,12 +816,13 @@ class OpWiseCLinker(link.LocalLinker):
desc = (node.op, desc = (node.op,
tuple(input.type for input in node.inputs), tuple(input.type for input in node.inputs),
tuple(input.type for input in node.inputs), tuple(input.type for input in node.inputs),
tuple(output in no_recycling for output in node.outputs)) tuple(output in no_recycling for output in node.outputs),
tuple(node.inputs.count(input) for input in node.inputs))
try: try:
cl = self.__cache__.get(desc) cl = self.__cache__.get(desc)
except Exception, exc: except Exception, exc:
print "harmless warning: failed to hash %s: %s" % (node, exc) print >> sys.stderr, "INFO: failed to hash %s: %s. Node will not be cached." % (node, exc)
cl = None cl = None
if cl is None: if cl is None:
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling]) cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论