提交 ae3c3c7b authored 作者: abergeron's avatar abergeron

Merge pull request #1701 from nouiz/local_greedy_distributor

Local greedy distributor, make sure it won't be skipped again.
......@@ -164,6 +164,8 @@ def conv3d(signals, filters,
border_mode='valid'):
"""Convolve spatio-temporal filters with a movie.
It flip the filters.
:param signals: timeseries of images whose pixels have color channels.
shape: [Ns, Ts, C, Hs, Ws]
:param filters: spatio-temporal filters
......@@ -173,6 +175,9 @@ def conv3d(signals, filters,
:param border_mode: The only one tested is 'valid'.
:note: Work on the GPU.
Another way to define signals: (batch, time, in channel, row, column)
Another way to define filters: (out channel,time,in channel, row, column)
"""
if isinstance(border_mode, str):
......
......@@ -3936,7 +3936,7 @@ def attempt_distribution(factor, num, denum):
neg_pairs))), num, denum
@gof.local_optimizer([T.mul, T.true_div])
@gof.local_optimizer([T.mul, T.true_div, T.inv])
def local_greedy_distributor(node):
"""
This optimization tries to apply distributivity of multiplication
......
......@@ -19,9 +19,8 @@ from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar,
gemm_inplace, gemm_no_inplace,
InconsistencyError, Ger, ger, ger_destructive)
from theano.tests import unittest_tools
from test_basic import (_approx_eq, as_tensor_variable, inplace_func,
compile, inplace)
#, constant, eval_outputs)
from test_basic import (as_tensor_variable, inplace_func,
compile, inplace)
import theano.tensor.blas_scipy
......@@ -50,7 +49,6 @@ class t_gemm(TestCase):
"""
def setUp(self):
unittest_tools.seed_rng()
_approx_eq.debug = 0
Gemm.debug = False
@staticmethod
......@@ -84,8 +82,7 @@ class t_gemm(TestCase):
z_after = self._gemm(z_orig, a, x, y, b)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1
self.assertTrue(_approx_eq(z_after, z))
unittest_tools.assert_allclose(z_after, z)
if a == 0.0 and b == 1.0:
return
elif z_orig.size == 0:
......@@ -150,7 +147,6 @@ class t_gemm(TestCase):
self.rand(3, 5), self.rand(5, 4), -1.0)
def test10(self):
_approx_eq.debug = 1
self.cmp(self.rand(3, 4), -1.0, self.rand(3, 5), self.rand(5, 4), 0.0)
def test11(self):
......@@ -281,14 +277,11 @@ class t_gemm(TestCase):
f = inplace_func([], gemm_inplace(tz, ta, tx, ty, tb),
mode=compile.Mode(optimizer=None, linker=l))
f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)),
(z_orig, z_after, z, z_after - z))
unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)),
(z_orig, z_after, z, z_after - z))
unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)),
(z_orig, z_after, z, z_after - z))
unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True))
#tz.value *= 0 # clear z's value
y_T = ty.get_value(borrow=True).T
......@@ -298,7 +291,7 @@ class t_gemm(TestCase):
f()
# test that the transposed version of multiplication gives
# same answer
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True).T))
unittest_tools.assert_allclose(z_after, tz.get_value(borrow=True).T)
t(C, A, B)
t(C.T, A, B)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论