提交 087b430b authored 作者: bergstra@ip05.m's avatar bergstra@ip05.m

scipy.signal.convolve2d causes errors on atexit. i moved imports to make it…

scipy.signal.convolve2d causes errors on atexit. i moved imports to make it possible to use our convolve op without triggering this problem
上级 eee69c5a
...@@ -2,8 +2,6 @@ import numpy as N ...@@ -2,8 +2,6 @@ import numpy as N
import theano import theano
import theano.tensor as T import theano.tensor as T
from theano import gof, Op, tensor from theano import gof, Op, tensor
from scipy.signal.signaltools import _valfrommode, _bvalfromboundary
from scipy.signal.sigtools import _convolve2d
from theano.printing import Print from theano.printing import Print
def getFilterOutShp(inshp, kshp, (dx,dy)=(1,1), mode='valid'): def getFilterOutShp(inshp, kshp, (dx,dy)=(1,1), mode='valid'):
...@@ -60,6 +58,9 @@ class ConvOp(Op): ...@@ -60,6 +58,9 @@ class ConvOp(Op):
""" """
By default if len(img2d.shape)==3, we By default if len(img2d.shape)==3, we
""" """
# TODO: move these back out to global scope when they no longer cause an atexit error
from scipy.signal.signaltools import _valfrommode, _bvalfromboundary
from scipy.signal.sigtools import _convolve2d
if z[0] is None: if z[0] is None:
z[0] = N.zeros((self.bsize,)+(self.nkern,)+tuple(self.outshp)) z[0] = N.zeros((self.bsize,)+(self.nkern,)+tuple(self.outshp))
zz=z[0] zz=z[0]
......
...@@ -3,9 +3,6 @@ import sys, time, unittest ...@@ -3,9 +3,6 @@ import sys, time, unittest
import numpy import numpy
import numpy as N import numpy as N
from scipy.signal import convolve2d
from scipy.signal.sigtools import _convolve2d
from scipy.signal.signaltools import _valfrommode, _bvalfromboundary
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano import function, Mode from theano import function, Mode
...@@ -43,6 +40,7 @@ class TestConvOp(unittest.TestCase): ...@@ -43,6 +40,7 @@ class TestConvOp(unittest.TestCase):
print '\n\n*************************************************' print '\n\n*************************************************'
print ' TEST CONVOLUTION' print ' TEST CONVOLUTION'
print '*************************************************' print '*************************************************'
from scipy.signal import convolve2d
if 0: if 0:
# fixed parameters # fixed parameters
...@@ -189,6 +187,9 @@ class TestConvOp(unittest.TestCase): ...@@ -189,6 +187,9 @@ class TestConvOp(unittest.TestCase):
print 'speed up ConvOp vs convolve2d: %.3f'%d.mean(),d print 'speed up ConvOp vs convolve2d: %.3f'%d.mean(),d
def test_multilayer_conv(self): def test_multilayer_conv(self):
# causes an atexit problem
from scipy.signal.sigtools import _convolve2d
from scipy.signal.signaltools import _valfrommode, _bvalfromboundary
# fixed parameters # fixed parameters
bsize = 1 # batch size bsize = 1 # batch size
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论