提交 ae3c3c7b authored 作者: abergeron's avatar abergeron

Merge pull request #1701 from nouiz/local_greedy_distributor

Local greedy distributor, make sure it won't be skipped again.
...@@ -164,6 +164,8 @@ def conv3d(signals, filters, ...@@ -164,6 +164,8 @@ def conv3d(signals, filters,
border_mode='valid'): border_mode='valid'):
"""Convolve spatio-temporal filters with a movie. """Convolve spatio-temporal filters with a movie.
It flip the filters.
:param signals: timeseries of images whose pixels have color channels. :param signals: timeseries of images whose pixels have color channels.
shape: [Ns, Ts, C, Hs, Ws] shape: [Ns, Ts, C, Hs, Ws]
:param filters: spatio-temporal filters :param filters: spatio-temporal filters
...@@ -173,6 +175,9 @@ def conv3d(signals, filters, ...@@ -173,6 +175,9 @@ def conv3d(signals, filters,
:param border_mode: The only one tested is 'valid'. :param border_mode: The only one tested is 'valid'.
:note: Work on the GPU. :note: Work on the GPU.
Another way to define signals: (batch, time, in channel, row, column)
Another way to define filters: (out channel,time,in channel, row, column)
""" """
if isinstance(border_mode, str): if isinstance(border_mode, str):
......
...@@ -3936,7 +3936,7 @@ def attempt_distribution(factor, num, denum): ...@@ -3936,7 +3936,7 @@ def attempt_distribution(factor, num, denum):
neg_pairs))), num, denum neg_pairs))), num, denum
@gof.local_optimizer([T.mul, T.true_div]) @gof.local_optimizer([T.mul, T.true_div, T.inv])
def local_greedy_distributor(node): def local_greedy_distributor(node):
""" """
This optimization tries to apply distributivity of multiplication This optimization tries to apply distributivity of multiplication
......
...@@ -19,9 +19,8 @@ from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar, ...@@ -19,9 +19,8 @@ from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar,
gemm_inplace, gemm_no_inplace, gemm_inplace, gemm_no_inplace,
InconsistencyError, Ger, ger, ger_destructive) InconsistencyError, Ger, ger, ger_destructive)
from theano.tests import unittest_tools from theano.tests import unittest_tools
from test_basic import (_approx_eq, as_tensor_variable, inplace_func, from test_basic import (as_tensor_variable, inplace_func,
compile, inplace) compile, inplace)
#, constant, eval_outputs)
import theano.tensor.blas_scipy import theano.tensor.blas_scipy
...@@ -50,7 +49,6 @@ class t_gemm(TestCase): ...@@ -50,7 +49,6 @@ class t_gemm(TestCase):
""" """
def setUp(self): def setUp(self):
unittest_tools.seed_rng() unittest_tools.seed_rng()
_approx_eq.debug = 0
Gemm.debug = False Gemm.debug = False
@staticmethod @staticmethod
...@@ -84,8 +82,7 @@ class t_gemm(TestCase): ...@@ -84,8 +82,7 @@ class t_gemm(TestCase):
z_after = self._gemm(z_orig, a, x, y, b) z_after = self._gemm(z_orig, a, x, y, b)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z) #print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1 unittest_tools.assert_allclose(z_after, z)
self.assertTrue(_approx_eq(z_after, z))
if a == 0.0 and b == 1.0: if a == 0.0 and b == 1.0:
return return
elif z_orig.size == 0: elif z_orig.size == 0:
...@@ -150,7 +147,6 @@ class t_gemm(TestCase): ...@@ -150,7 +147,6 @@ class t_gemm(TestCase):
self.rand(3, 5), self.rand(5, 4), -1.0) self.rand(3, 5), self.rand(5, 4), -1.0)
def test10(self): def test10(self):
_approx_eq.debug = 1
self.cmp(self.rand(3, 4), -1.0, self.rand(3, 5), self.rand(5, 4), 0.0) self.cmp(self.rand(3, 4), -1.0, self.rand(3, 5), self.rand(5, 4), 0.0)
def test11(self): def test11(self):
...@@ -281,14 +277,11 @@ class t_gemm(TestCase): ...@@ -281,14 +277,11 @@ class t_gemm(TestCase):
f = inplace_func([], gemm_inplace(tz, ta, tx, ty, tb), f = inplace_func([], gemm_inplace(tz, ta, tx, ty, tb),
mode=compile.Mode(optimizer=None, linker=l)) mode=compile.Mode(optimizer=None, linker=l))
f() f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)), unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
(z_orig, z_after, z, z_after - z))
f() f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)), unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
(z_orig, z_after, z, z_after - z))
f() f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)), unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
(z_orig, z_after, z, z_after - z))
#tz.value *= 0 # clear z's value #tz.value *= 0 # clear z's value
y_T = ty.get_value(borrow=True).T y_T = ty.get_value(borrow=True).T
...@@ -298,7 +291,7 @@ class t_gemm(TestCase): ...@@ -298,7 +291,7 @@ class t_gemm(TestCase):
f() f()
# test that the transposed version of multiplication gives # test that the transposed version of multiplication gives
# same answer # same answer
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True).T)) unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True).T)
t(C, A, B) t(C, A, B)
t(C.T, A, B) t(C.T, A, B)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论