提交 490ef97a authored 作者: lamblin's avatar lamblin

Merge pull request #710 from nouiz/small

Small
...@@ -693,7 +693,6 @@ class T_Scan(unittest.TestCase): ...@@ -693,7 +693,6 @@ class T_Scan(unittest.TestCase):
outputs_info = [None]) outputs_info = [None])
inp = numpy.arange(5).astype('float64') inp = numpy.arange(5).astype('float64')
rval = theano.function([x], y, updates=updates)(inp) rval = theano.function([x], y, updates=updates)(inp)
import ipdb; ipdb.set_trace()
assert numpy.all(rval == inp[:-1]) assert numpy.all(rval == inp[:-1])
# simple rnn, one input, one state, weights for each; input/state are # simple rnn, one input, one state, weights for each; input/state are
......
...@@ -708,7 +708,7 @@ class ConvOp(Op): ...@@ -708,7 +708,7 @@ class ConvOp(Op):
raise NotImplementedError('todo') raise NotImplementedError('todo')
if self.dx not in (1, 2) or self.dy not in (1, 2): if self.dx not in (1, 2) or self.dy not in (1, 2):
raise Exception("ERROR: We disable ConvOp.grad now when dx or "\ raise NotImplementedError("ERROR: We disable ConvOp.grad now when dx or "\
"dy are different from 1 and 2, as there is a bug in it.") "dy are different from 1 and 2, as there is a bug in it.")
all_shape = self.imshp is not None and self.kshp is not None and \ all_shape = self.imshp is not None and self.kshp is not None and \
......
import sys, time, unittest import sys
import time
import unittest
import numpy import numpy
import theano import theano
...@@ -10,6 +12,7 @@ from theano.tensor.nnet import conv ...@@ -10,6 +12,7 @@ from theano.tensor.nnet import conv
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
class TestConv2D(unittest.TestCase): class TestConv2D(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -18,16 +21,18 @@ class TestConv2D(unittest.TestCase): ...@@ -18,16 +21,18 @@ class TestConv2D(unittest.TestCase):
self.filters = T.dtensor4('filters') self.filters = T.dtensor4('filters')
def validate(self, image_shape, filter_shape, def validate(self, image_shape, filter_shape,
border_mode='valid', subsample=(1,1), border_mode='valid', subsample=(1, 1),
N_image_shape=None, N_filter_shape=None, N_image_shape=None, N_filter_shape=None,
input=None, filters=None, input=None, filters=None,
unroll_batch=None, unroll_kern=None, unroll_patch=None, unroll_batch=None, unroll_kern=None, unroll_patch=None,
verify_grad=True, should_raise=False): verify_grad=True, should_raise=False):
if N_image_shape is None: if N_image_shape is None:
N_image_shape = [T.get_constant_value(T.as_tensor_variable(x)) for x in image_shape] N_image_shape = [T.get_constant_value(T.
as_tensor_variable(x)) for x in image_shape]
if N_filter_shape is None: if N_filter_shape is None:
N_filter_shape = [T.get_constant_value(T.as_tensor_variable(x)) for x in filter_shape] N_filter_shape = [T.get_constant_value(T.
as_tensor_variable(x)) for x in filter_shape]
if not input: if not input:
input = self.input input = self.input
...@@ -47,48 +52,53 @@ class TestConv2D(unittest.TestCase): ...@@ -47,48 +52,53 @@ class TestConv2D(unittest.TestCase):
theano_conv = theano.function([input, filters], output) theano_conv = theano.function([input, filters], output)
# initialize input and compute result # initialize input and compute result
image_data = numpy.random.random(N_image_shape) image_data = numpy.random.random(N_image_shape)
filter_data = numpy.random.random(N_filter_shape) filter_data = numpy.random.random(N_filter_shape)
try: try:
theano_output = theano_conv(image_data, filter_data) theano_output = theano_conv(image_data, filter_data)
except ValueError: except ValueError:
if not should_raise: raise if not should_raise:
raise
return return
else: else:
if should_raise: raise Exception("ConvOp should have generated an error") if should_raise:
raise Exception(
"ConvOp should have generated an error")
############# REFERENCE IMPLEMENTATION ############ ############# REFERENCE IMPLEMENTATION ############
s = 1. s = 1.
orig_image_data = image_data orig_image_data = image_data
if border_mode is not 'full': s = -1. if border_mode is not 'full':
s = -1.
out_shape2d = numpy.array(N_image_shape[-2:]) +\ out_shape2d = numpy.array(N_image_shape[-2:]) +\
s*numpy.array(N_filter_shape[-2:]) - s s * numpy.array(N_filter_shape[-2:]) - s
out_shape2d = numpy.ceil(out_shape2d / numpy.array(subsample)) out_shape2d = numpy.ceil(out_shape2d / numpy.array(subsample))
out_shape = (N_image_shape[0],N_filter_shape[0]) + tuple(out_shape2d) out_shape = (N_image_shape[0], N_filter_shape[0]) + tuple(out_shape2d)
ref_output = numpy.zeros(out_shape) ref_output = numpy.zeros(out_shape)
# loop over output feature maps # loop over output feature maps
ref_output.fill(0) ref_output.fill(0)
if border_mode=='full': if border_mode == 'full':
image_data2 = numpy.zeros((N_image_shape[0],N_image_shape[1], image_data2 = numpy.zeros((N_image_shape[0], N_image_shape[1],
N_image_shape[2]+2*N_filter_shape[2]-2, N_image_shape[2] + 2 * N_filter_shape[2] - 2,
N_image_shape[3]+2*N_filter_shape[3]-2)) N_image_shape[3] + 2 * N_filter_shape[3] - 2))
image_data2[:,:,N_filter_shape[2]-1:N_filter_shape[2]-1+N_image_shape[2], image_data2[:, :, N_filter_shape[2] - 1:N_filter_shape[2] - 1 + N_image_shape[2],
N_filter_shape[3]-1:N_filter_shape[3]-1+N_image_shape[3]] = image_data N_filter_shape[3] - 1:N_filter_shape[3] - 1 + N_image_shape[3]] = image_data
image_data = image_data2 image_data = image_data2
N_image_shape = image_data.shape N_image_shape = image_data.shape
for bb in range(N_image_shape[0]): for bb in range(N_image_shape[0]):
for nn in range(N_filter_shape[0]): for nn in range(N_filter_shape[0]):
for im0 in range(N_image_shape[1]): for im0 in range(N_image_shape[1]):
filter2d = filter_data[nn,im0,:,:] filter2d = filter_data[nn, im0, :, :]
image2d = image_data[bb,im0,:,:] image2d = image_data[bb, im0, :, :]
for row in range(ref_output.shape[2]): for row in range(ref_output.shape[2]):
irow = row * subsample[0]#image row irow = row * subsample[0] # image row
for col in range(ref_output.shape[3]): for col in range(ref_output.shape[3]):
icol = col * subsample[1]#image col icol = col * subsample[1] # image col
ref_output[bb,nn,row,col] += (image2d[irow:irow+N_filter_shape[2], ref_output[bb, nn, row, col] += (image2d[
icol:icol+N_filter_shape[3]]*filter2d[::-1,::-1] irow:irow + N_filter_shape[2],
).sum() icol:icol + N_filter_shape[3]] * filter2d[::-1,::-1]
).sum()
self.assertTrue(_allclose(theano_output, ref_output)) self.assertTrue(_allclose(theano_output, ref_output))
...@@ -96,135 +106,162 @@ class TestConv2D(unittest.TestCase): ...@@ -96,135 +106,162 @@ class TestConv2D(unittest.TestCase):
if verify_grad: if verify_grad:
utt.verify_grad(sym_conv2d, [orig_image_data, filter_data]) utt.verify_grad(sym_conv2d, [orig_image_data, filter_data])
def test_basic1(self): def test_basic1(self):
"""Tests that basic convolutions work for odd and even
dimensions of image and filter shapes, as well as rectangular
images and filters.
""" """
Tests that basic convolutions work for odd and even dimensions of image and filter self.validate((2, 2, 3, 3), (2, 2, 2, 2), 'valid', verify_grad=False)
shapes, as well as rectangular images and filters.
"""
self.validate((2,2,3,3), (2,2,2,2), 'valid', verify_grad=False)
def test_basic(self): def test_basic(self):
"""Tests that basic convolutions work for odd and even
dimensions of image and filter shapes, as well as rectangular
images and filters.
""" """
Tests that basic convolutions work for odd and even dimensions of image and filter self.validate((3, 2, 8, 8), (4, 2, 5, 5), 'valid', verify_grad=False)
shapes, as well as rectangular images and filters. self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid')
""" self.validate((3, 2, 7, 5), (5, 2, 3, 2), 'valid', verify_grad=False)
self.validate((3,2,8,8), (4,2,5,5), 'valid', verify_grad=False) self.validate((3, 2, 8, 8), (4, 2, 5, 5), 'full', verify_grad=False)
self.validate((3,2,7,5), (5,2,2,3), 'valid') self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full')
self.validate((3,2,7,5), (5,2,3,2), 'valid', verify_grad=False)
self.validate((3,2,8,8), (4,2,5,5), 'full', verify_grad=False)
self.validate((3,2,7,5), (5,2,2,3), 'full')
# test filter same size as input # test filter same size as input
def test_img_kernel_same_shape(self): def test_img_kernel_same_shape(self):
self.validate((3,2,3,3), (4,2,3,3), 'full') self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'full')
self.validate((3,2,3,3), (4,2,3,3), 'valid') self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'valid')
def test_unroll_patch_true(self): def test_unroll_patch_true(self):
""" """
Test basic convs with True. Test basic convs with True.
""" """
self.validate((3,2,7,5), (5,2,2,3), 'valid', unroll_patch=True) self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', unroll_patch=True)
self.validate((3,2,7,5), (5,2,2,3), 'full', unroll_patch=True) self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', unroll_patch=True)
self.validate((3,2,3,3), (4,2,3,3), 'valid', unroll_patch=True, verify_grad=False) self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'valid',
unroll_patch=True, verify_grad=False)
def test_unroll_patch_false(self): def test_unroll_patch_false(self):
""" """
Test basic convs with False. Test basic convs with False.
""" """
self.validate((3,2,7,5), (5,2,2,3), 'valid', unroll_patch=False) self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', unroll_patch=False)
self.validate((3,2,7,5), (5,2,2,3), 'full', unroll_patch=False) self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', unroll_patch=False)
self.validate((3,2,3,3), (4,2,3,3), 'valid', unroll_patch=False, verify_grad=False) self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'valid',
unroll_patch=False, verify_grad=False)
def test_unroll_patch_true_fail(self): def test_unroll_patch_true_fail(self):
""" """
Test basic convs with True. Test basic convs with True.
""" """
self.validate((3,2,7,5), (5,2,2,3), 'valid', unroll_patch=True, self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', unroll_patch=True,
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True) N_image_shape=(1, 3, 3, 3), N_filter_shape=(6, 3, 2, 2),
self.validate((3,2,7,5), (5,2,2,3), 'full', unroll_patch=True, should_raise=True)
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True) self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', unroll_patch=True,
self.validate((3,2,3,3), (4,2,3,3), 'valid', unroll_patch=True, N_image_shape=(1, 3, 3, 3), N_filter_shape=(6, 3, 2, 2),
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True) should_raise=True)
self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'valid', unroll_patch=True,
N_image_shape=(1, 3, 3, 3), N_filter_shape=(6, 3, 2, 2),
should_raise=True)
def test_unroll_special(self): def test_unroll_special(self):
""" """
(unroll_kern, unroll_batch) in (0,1),(1,0) is special case. (unroll_kern, unroll_batch) in (0,1),(1,0) is special case.
""" """
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=1) self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid', unroll_batch=1)
def test_unroll_batch(self): def test_unroll_batch(self):
""" """
Test mini-batch unrolling for various legal values. Test mini-batch unrolling for various legal values.
""" """
# mini-batch of size 6 is multiple of 2 and 3. Should work. # mini-batch of size 6 is multiple of 2 and 3. Should work.
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=2, verify_grad=False) self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=3, verify_grad=False) unroll_batch=2, verify_grad=False)
self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
unroll_batch=3, verify_grad=False)
def test_unroll_kern(self): def test_unroll_kern(self):
""" """
Test kernel unrolling for various legal values. Test kernel unrolling for various legal values.
""" """
# 6 filters is a multiple of 2 and 3. Should work. # 6 filters is a multiple of 2 and 3. Should work.
self.validate((2,3,3,3), (6,3,2,2), 'valid', unroll_kern=2, verify_grad=False) self.validate((2, 3, 3, 3), (6, 3, 2, 2), 'valid', unroll_kern=2,
self.validate((2,3,3,3), (6,3,2,2), 'valid', unroll_kern=3, verify_grad=False) verify_grad=False)
self.validate((2, 3, 3, 3), (6, 3, 2, 2), 'valid', unroll_kern=3,
verify_grad=False)
def test_unroll_batch_kern(self): def test_unroll_batch_kern(self):
""" """Test mini-batch unrolling with kernel unrolling for various
Test mini-batch unrolling with kernel unrolling for various legal values. legal values.
""" """
# mini-batch of size 6 is multiple of 2 and 3. Should work. # mini-batch of size 6 is multiple of 2 and 3. Should work.
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=2, unroll_kern=3, verify_grad=False) self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=3, unroll_kern=3, verify_grad=False) unroll_batch=2, unroll_kern=3, verify_grad=False)
self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
unroll_batch=3, unroll_kern=3, verify_grad=False)
# 6 filters is a multiple of 2 and 3. Should work. # 6 filters is a multiple of 2 and 3. Should work.
self.validate((2,3,3,3), (6,3,2,2), 'valid', unroll_batch=2, unroll_kern=2, verify_grad=False) self.validate((2, 3, 3, 3), (6, 3, 2, 2), 'valid',
self.validate((2,3,3,3), (6,3,2,2), 'valid', unroll_batch=2, unroll_kern=3, verify_grad=False) unroll_batch=2, unroll_kern=2, verify_grad=False)
self.validate((2, 3, 3, 3), (6, 3, 2, 2), 'valid',
unroll_batch=2, unroll_kern=3, verify_grad=False)
def test_unroll_batch_kern_fail(self): def test_unroll_batch_kern_fail(self):
""" """Test mini-batch unrolling with kernel unrolling for various
Test mini-batch unrolling with kernel unrolling for various legal values, but pass bad input. legal values, but pass bad input. All those test must
All those test must generate errors generate errors
""" """
# mini-batch of size 6 is multiple of 2 and 3. Should work. # mini-batch of size 6 is multiple of 2 and 3. Should work.
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=2, unroll_kern=3, self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
N_image_shape=(7,2,3,3), N_filter_shape=(3,2,2,2), should_raise=True) unroll_batch=2, unroll_kern=3,
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=3, unroll_kern=3, N_image_shape=(7, 2, 3, 3), N_filter_shape=(3, 2, 2, 2),
N_image_shape=(6,2,3,3), N_filter_shape=(4,2,2,2), should_raise=True) should_raise=True)
self.validate((2,3,3,3), (6,3,2,2), 'valid', unroll_batch=2, unroll_kern=2, self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True) unroll_batch=3, unroll_kern=3,
self.validate((2,3,3,3), (6,3,2,2), 'valid', unroll_batch=2, unroll_kern=3, N_image_shape=(6, 2, 3, 3), N_filter_shape=(4, 2, 2, 2),
N_image_shape=(2,3,3,3), N_filter_shape=(5,3,2,2), should_raise=True) should_raise=True)
self.validate((2, 3, 3, 3), (6, 3, 2, 2), 'valid',
unroll_batch=2, unroll_kern=2,
N_image_shape=(1, 3, 3, 3), N_filter_shape=(6, 3, 2, 2),
should_raise=True)
self.validate((2, 3, 3, 3), (6, 3, 2, 2), 'valid',
unroll_batch=2, unroll_kern=3,
N_image_shape=(2, 3, 3, 3), N_filter_shape=(5, 3, 2, 2),
should_raise=True)
def test_subsample(self): def test_subsample(self):
""" """
Tests convolution where subsampling != (1,1) Tests convolution where subsampling != (1,1)
""" """
self.validate((3,2,7,5), (5,2,2,3), 'valid', subsample=(2,2)) self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', subsample=(2, 2))
self.validate((3,2,7,5), (5,2,2,3), 'full', subsample=(2,2)) self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', subsample=(2, 2))
self.validate((3,2,7,5), (5,2,2,3), 'valid', subsample=(2,1)) self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', subsample=(2, 1))
# Fails as of 2012-04-12 # Fails as of 2012-04-12
self.validate((1,1,6,6), (1,1,3,3), 'valid', subsample=(3,3)) self.assertRaises(NotImplementedError, self.validate, (1, 1, 6, 6),
(1, 1, 3, 3), 'valid', subsample=(3, 3))
def test_shape_Constant_tensor(self): def test_shape_Constant_tensor(self):
""" """
Tests convolution where the {image,filter}_shape is a Constant tensor. Tests convolution where the {image,filter}_shape is a Constant tensor.
""" """
as_t=T.as_tensor_variable as_t = T.as_tensor_variable
self.validate((as_t(3),as_t(2),as_t(7),as_t(5)), (5,2,2,3), 'valid') self.validate((as_t(3), as_t(2), as_t(7), as_t(5)), (5, 2,
self.validate(as_t([3,2,7,5]), (5,2,2,3), 'valid') 2, 3), 'valid')
self.validate(as_t((3,2,7,5)), (5,2,2,3), 'valid') self.validate(as_t([3, 2, 7, 5]), (5, 2, 2, 3), 'valid')
self.validate((3,2,7,5), (as_t(5),as_t(2),as_t(2),as_t(3)), 'valid') self.validate(as_t((3, 2, 7, 5)), (5, 2, 2, 3), 'valid')
self.validate((3,2,7,5), as_t([5,2,2,3]), 'valid') self.validate((3, 2, 7, 5), (as_t(5), as_t(2), as_t(2),
self.validate((3,2,7,5), as_t((5,2,2,3)), 'valid') as_t(3)), 'valid')
self.validate(as_t([3,2,7,5]), as_t([5,2,2,3]), 'full') self.validate((3, 2, 7, 5), as_t([5, 2, 2, 3]), 'valid')
self.validate((3, 2, 7, 5), as_t((5, 2, 2, 3)), 'valid')
self.validate(as_t([3, 2, 7, 5]), as_t([5, 2, 2, 3]), 'full')
def test_invalid_filter_shape(self): def test_invalid_filter_shape(self):
""" """
Tests scenario where filter_shape[1] != input_shape[1] Tests scenario where filter_shape[1] != input_shape[1]
""" """
self.assertRaises(AssertionError, self.validate, (3,2,8,8), (4,3,5,5), self.assertRaises(AssertionError, self.validate,
(3, 2, 8, 8), (4, 3, 5, 5),
'valid') 'valid')
def test_invalid_input_shape(self): def test_invalid_input_shape(self):
...@@ -291,23 +328,24 @@ class TestConv2D(unittest.TestCase): ...@@ -291,23 +328,24 @@ class TestConv2D(unittest.TestCase):
Test convolutions for various pieces of missing info. Test convolutions for various pieces of missing info.
""" """
self.validate(None, None, self.validate(None, None,
N_image_shape=(3,2,8,8), N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4,2,5,5)) N_filter_shape=(4, 2, 5, 5))
self.validate((3,2,None,None), None, self.validate((3, 2, None, None), None,
N_image_shape=(3,2,8,8), N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4,2,5,5)) N_filter_shape=(4, 2, 5, 5))
self.validate((None,2,None,None), (None,2,5,5), self.validate((None, 2, None, None), (None, 2, 5, 5),
N_image_shape=(3,2,8,8), N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4,2,5,5)) N_filter_shape=(4, 2, 5, 5))
def test_full_mode(self): def test_full_mode(self):
""" """
Tests basic convolution in full mode and case where filter Tests basic convolution in full mode and case where filter
is larger than the input image. is larger than the input image.
""" """
self.validate((3,2,5,5), (4,2,8,8), 'full') self.validate((3, 2, 5, 5), (4, 2, 8, 8), 'full')
def f(): def f():
self.validate((3,2,5,5), (4,2,8,8), 'valid') self.validate((3, 2, 5, 5), (4, 2, 8, 8), 'valid')
self.assertRaises(Exception, f) self.assertRaises(Exception, f)
def test_wrong_input(self): def test_wrong_input(self):
...@@ -328,4 +366,5 @@ class TestConv2D(unittest.TestCase): ...@@ -328,4 +366,5 @@ class TestConv2D(unittest.TestCase):
crashed in this following case. I changed the c code to don't hit crashed in this following case. I changed the c code to don't hit
gcc bug. So it should not crash anymore gcc bug. So it should not crash anymore
""" """
self.validate((1,10,213,129), (46,10,212,1), 'valid', verify_grad=False) self.validate((1, 10, 213, 129), (46, 10, 212, 1), 'valid',
verify_grad=False)
import unittest
import theano import theano
from theano.updates import Updates from theano.updates import Updates
import theano.tensor as T import theano.tensor as T
def test_updates_setitem(): class test_ifelse(unittest.TestCase):
ok = True
up = Updates() def test_updates_init(self):
sv = theano.shared('asdf') self.assertRaises(TypeError, Updates, dict(d=3))
# keys have to be SharedVariables sv = theano.shared('asdf')
try: Updates({sv:3})
up[5] = 7
ok = False
except TypeError:
ok = True
assert ok
# keys have to be SharedVariables def test_updates_setitem(self):
try:
up[T.vector()] = 7
ok = False
except TypeError:
ok = True ok = True
assert ok
# keys have to be SharedVariables
up[theano.shared(88)] = 7
def test_updates_add(): up = Updates()
sv = theano.shared('asdf')
up1 = Updates() # keys have to be SharedVariables
up2 = Updates() self.assertRaises(TypeError, up.__setitem__, 5, 7)
self.assertRaises(TypeError, up.__setitem__, T.vector(), 7)
a = theano.shared('a') up[theano.shared(88)] = 7
b = theano.shared('b')
def test_updates_add(self):
assert not up1 + up2 up1 = Updates()
up2 = Updates()
up1[a] = 5 a = theano.shared('a')
b = theano.shared('b')
# test that addition works assert not up1 + up2
assert up1
assert up1 + up2
assert not up2
assert len(up1+up2)==1 up1[a] = 5
assert (up1 + up2)[a] == 5
up2[b] = 7 # test that addition works
assert up1 assert up1
assert up1 + up2 assert up1 + up2
assert up2 assert not up2
assert len(up1+up2)==2 assert len(up1 + up2) == 1
assert (up1 + up2)[a] == 5 assert (up1 + up2)[a] == 5
assert (up1 + up2)[b] == 7
assert a in (up1 + up2) up2[b] = 7
assert b in (up1 + up2) assert up1
assert up1 + up2
assert up2
# this works even though there is a collision assert len(up1 + up2) == 2
# because values all match assert (up1 + up2)[a] == 5
assert len(up1 + up1 + up1)==1 assert (up1 + up2)[b] == 7
up2[a] = 8 # a gets different value in up1 and up2 assert a in (up1 + up2)
try: assert b in (up1 + up2)
up1 + up2
assert 0
except KeyError:
pass
# reassigning to a key works fine right? # this works even though there is a collision
up2[a] = 10 # because values all match
assert len(up1 + up1 + up1) == 1
up2[a] = 8 # a gets different value in up1 and up2
try:
up1 + up2
assert 0
except KeyError:
pass
# reassigning to a key works fine right?
up2[a] = 10
...@@ -19,6 +19,15 @@ class Updates(dict): ...@@ -19,6 +19,15 @@ class Updates(dict):
This mapping supports the use of the "+" operator for the union of updates. This mapping supports the use of the "+" operator for the union of updates.
""" """
def __init__(self, *key, **kwargs):
ret = super(Updates, self).__init__(*key, **kwargs)
for key in self:
if not isinstance(key, SharedVariable):
raise TypeError(
'Updates keys must inherit from SharedVariable',
key)
return ret
def __setitem__(self, key, value): def __setitem__(self, key, value):
if isinstance(key, SharedVariable): if isinstance(key, SharedVariable):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论