提交 382d2ed1 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1571 from lamblin/test_bug_cho

Add test for gh-1549
"""
Utility classes and methods to pickle parts of symbolic graph.
These pickled graphs can be used, for instance, as cases for
unit tests or regression tests.
"""
__docformat__ = "restructuredtext en"
__authors__ = "Pascal Lamblin"
__copyright__ = "Copyright 2013, Universite de Montreal"
__license__ = "3-clause BSD"
import pickle
import sys
import theano
sys.setrecursionlimit(3000)
Pickler = pickle.Pickler
class StripPickler(Pickler):
"""
Subclass of Pickler that strips unnecessary attributes from Theano objects.
Example of use::
fn_args = dict(inputs=inputs,
outputs=outputs,
updates=updates)
dest_pkl = 'my_test.pkl'
f = open(dest_pkl, 'wb')
strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(fn_args)
f.close()
"""
def save(self, obj):
# Remove the tag.trace attribute from Variable and Apply nodes
if isinstance(obj, theano.gof.utils.scratchpad):
if hasattr(obj, 'trace'):
del obj.trace
# Remove manually-added docstring of Elemwise ops
elif (isinstance(obj, theano.tensor.Elemwise)):
if '__doc__' in obj.__dict__:
del obj.__dict__['__doc__']
return Pickler.save(self, obj)
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
import copy import copy
import logging import logging
import pickle
import os
import time import time
import unittest import unittest
...@@ -2645,6 +2647,18 @@ class test_shapeoptimizer(unittest.TestCase): ...@@ -2645,6 +2647,18 @@ class test_shapeoptimizer(unittest.TestCase):
f = theano.function([X], expr, mode=mode) f = theano.function([X], expr, mode=mode)
print f([[1, 2], [2, 3]]) print f([[1, 2], [2, 3]])
def test_no_cycle(self):
# Optimizing this graph resulted in a cycle, see gh-1549
# This test depends on cuda
import theano.sandbox.cuda as cuda
if not cuda.cuda_available:
raise SkipTest("cuda not available")
pkl_filename = os.path.join(os.path.dirname(theano.__file__),
'tensor', 'tests', 'shape_opt_cycle.pkl')
fn_args = pickle.load(open(pkl_filename, "rb"))
theano.function(**fn_args)
class test_assert(utt.InferShapeTester): class test_assert(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论