提交 49587bc3 authored 作者: James Bergstra's avatar James Bergstra

Added a Ger Op with tests to tensor/blas.py

上级 c9a0def3
...@@ -211,6 +211,84 @@ class Gemv(Op): ...@@ -211,6 +211,84 @@ class Gemv(Op):
gemv_no_inplace = Gemv(inplace=False) gemv_no_inplace = Gemv(inplace=False)
gemv_inplace = Gemv(inplace=True) gemv_inplace = Gemv(inplace=True)
class Ger(Op):
"""
BLAS defines general rank-1 update GER as A <- A + alpha x y'
for matrix A, scalar alpha, vectors x and y.
This interface to GER allows non-destructive operation on A via the
`destructive`
argument to the constructor.
:TODO: Create better classes ScipyGer and CGer that inherit from this class
and override the make_thunk() method to use Scipy and C respectively.
"""
def __init__(self, destructive):
self.destructive=destructive
if destructive:
self.destroy_map={0:[0]}
def __eq__(self, other):
return type(self)==type(other) and self.destructive == other.destructive
def __hash__(self):
return hash(type(self)) ^ hash(self.destructive)
def __str__(self):
if self.destructive:
return 'Ger{destructive}'
else:
return 'Ger{non-destructive}'
def make_node(self, A, alpha, x, y):
A = T.as_tensor_variable(A)
y = T.as_tensor_variable(y)
x = T.as_tensor_variable(x)
alpha = T.as_tensor_variable(alpha)
if len(set([A.dtype, alpha.dtype, x.dtype, y.dtype])) != 1:
raise TypeError('ger requires matching dtypes',
(A.dtype, alpha.dtype, x.dtype, y.dtype))
if alpha.ndim != 0:
raise TypeError('ger requires scalar alpha', alpha.type)
if A.ndim != 2:
raise TypeError('ger requires matrix for A', A.type)
if x.ndim != 1:
raise TypeError('ger requires vector for x', x.type)
if y.ndim != 1:
raise TypeError('ger requires vector for y', y.type)
if x.dtype not in ('float32', 'float64', 'complex64', 'complex128'):
raise TypeError('only float and complex types supported', x.dtype)
return Apply(self, [A, alpha, x, y], [A.type()])
def make_thunk(self, node, storage_map, compute_map, no_recycling):
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
# get vars for containers
cA, calpha, cx, cy = node_input_storage
cZ, = node_output_storage
def rval():
A = cA[0]
if self.destructive:
A = cA[0]
else:
A = cA[0].copy()
A += calpha[0] * numpy.outer(cx[0], cy[0])
cZ[0] = A
#TODO: If this is currently an unofficial part of the thunk API,
# then maybe it should be documented and made official?
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.lazy = False
return rval
ger = Ger(destructive=False)
ger_destructive = Ger(destructive=True)
def default_blas_ldflags(): def default_blas_ldflags():
try: try:
#if numpy was linked with library that are not installed, we can't reuse them. #if numpy was linked with library that are not installed, we can't reuse them.
......
...@@ -14,10 +14,11 @@ from numpy.testing import assert_, assert_array_almost_equal ...@@ -14,10 +14,11 @@ from numpy.testing import assert_, assert_array_almost_equal
#from theano.tensor.blas import * #from theano.tensor.blas import *
from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar, _is_real_matrix, from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar, _is_real_matrix,
_gemm_canonicalize, _factor_canonicalized, Gemm, Gemv, gemm_inplace, gemm_no_inplace, _gemm_canonicalize, _factor_canonicalized, Gemm, Gemv, gemm_inplace, gemm_no_inplace,
InconsistencyError) InconsistencyError,
Ger, ger, ger_destructive)
from unittest import TestCase from unittest import TestCase
from theano.tests import unittest_tools from theano.tests import unittest_tools
from copy import copy from copy import copy, deepcopy
from theano import Param, shared, config from theano import Param, shared, config
from test_basic import (_approx_eq, as_tensor_variable, inplace_func, from test_basic import (_approx_eq, as_tensor_variable, inplace_func,
...@@ -877,6 +878,194 @@ class TestGemv(TestCase): ...@@ -877,6 +878,194 @@ class TestGemv(TestCase):
self.assertRaises(ValueError, f, A_val, ones_3, ones_6) self.assertRaises(ValueError, f, A_val, ones_3, ones_6)
self.assertRaises(ValueError, f, A_val, ones_4, ones_6) self.assertRaises(ValueError, f, A_val, ones_4, ones_6)
class TestGer_make_node(TestCase):
def setUp(self):
self.iv = T.tensor(dtype='int32', broadcastable=(False,))
self.fv = T.tensor(dtype='float32', broadcastable=(False,))
self.fv1 = T.tensor(dtype='float32', broadcastable=(True,))
self.dv = T.tensor(dtype='float64', broadcastable=(False,))
self.dv1 = T.tensor(dtype='float64', broadcastable=(True,))
self.cv = T.tensor(dtype='complex64', broadcastable=(False,))
self.zv = T.tensor(dtype='complex128', broadcastable=(False,))
self.fv_2 = T.tensor(dtype='float32', broadcastable=(False,))
self.fv1_2 = T.tensor(dtype='float32', broadcastable=(True,))
self.dv_2 = T.tensor(dtype='float64', broadcastable=(False,))
self.dv1_2 = T.tensor(dtype='float64', broadcastable=(True,))
self.cv_2 = T.tensor(dtype='complex64', broadcastable=(False,))
self.zv_2 = T.tensor(dtype='complex128', broadcastable=(False,))
self.fm = T.fmatrix()
self.dm = T.dmatrix()
self.cm = T.cmatrix()
self.zm = T.zmatrix()
self.fa = T.fscalar()
self.da = T.dscalar()
self.ca = T.cscalar()
self.za = T.zscalar()
def test_works_on_all_valid_dtypes(s):
s.assertEquals(s.fm.type,
ger(s.fm, s.fa, s.fv, s.fv_2).type)
s.assertEquals(s.fm.type,
ger(s.fm, s.fa, s.fv, s.fv_2).type)
s.assertEquals(s.fm.type,
ger(s.fm, s.fa, s.fv, s.fv_2).type)
s.assertEquals(s.fm.type,
ger(s.fm, s.fa, s.fv, s.fv_2).type)
def test_fails_on_invalid_matching_dtypes(s):
s.assertRaises(TypeError,
ger, T.imatrix(), T.iscalar(), T.ivector(),
T.ivector())
def test_fails_for_nonscalar_alpha(s):
s.assertRaises(TypeError,
ger, s.fm, s.fm, s.fv, s.fv_2)
# boundary case - fv1 has the right dtype and could be dimshuffled to a
# scalar, but that's not make_node's job.
s.assertRaises(TypeError,
ger, s.fm, s.fv1, s.fv, s.fv_2)
# actually doing the aforementioned dimshuffle makes it work
s.assertEquals(s.fm.type,
ger(s.fm, s.fv1.dimshuffle(), s.fv, s.fv_2).type)
def test_fails_for_nonmatrix_A(s):
s.assertRaises(TypeError,
ger, s.fv, s.fa, s.fv, s.fv_2)
def test_fails_for_nonvector_x_or_y(s):
s.assertRaises(TypeError,
ger, s.fm, s.fa, s.fv.dimshuffle('x', 0), s.fv_2)
s.assertRaises(TypeError,
ger, s.fm, s.fa, s.fv, s.fv_2.dimshuffle('x', 0))
def test_fails_for_mixed_dtypes(s):
s.assertRaises(TypeError, ger, s.dm, s.fa, s.fv, s.fv_2)
s.assertRaises(TypeError, ger, s.fm, s.da, s.fv, s.fv_2)
s.assertRaises(TypeError, ger, s.fm, s.fa, s.dv, s.fv_2)
s.assertRaises(TypeError, ger, s.fm, s.fa, s.fv, s.dv_2)
s.assertRaises(TypeError, ger, s.cm, s.fa, s.fv, s.dv_2)
s.assertRaises(TypeError, ger, s.cm, s.fa, s.fv, s.zv_2)
# TODO: refactor this into some place where all OpTesters could use it.
class TestOpContractMixin(object):
# self.ops should be a list of instantiations of an Op class to test.
# self.other_op should be an op which is different from every op
other_op = T.add
def copy(self, x):
return copy(x)
def deepcopy(self, x):
return deepcopy(x)
def clone(self, op):
raise NotImplementedError('return new instance like `op`')
def test_eq_ger(self):
for i, op_i in enumerate(self.ops):
assert op_i == op_i
assert op_i == self.copy(op_i)
assert op_i == self.deepcopy(op_i)
assert op_i == self.clone(op_i)
assert op_i != self.other_op
for j, op_j in enumerate(self.ops):
if i == j: continue
assert op_i != op_j
def test_hash(self):
for i, op_i in enumerate(self.ops):
h_i = hash(op_i)
assert h_i == hash(op_i)
assert h_i == hash(self.copy(op_i))
assert h_i == hash(self.deepcopy(op_i))
assert h_i == hash(self.clone(op_i))
assert h_i != hash(self.other_op)
for j, op_j in enumerate(self.ops):
if i == j: continue
assert op_i != hash(op_j)
def test_name(self):
for op in self.ops:
s = str(op) # show that str works
assert s # names should not be empty
class TestGer_OpContract(TestCase, TestOpContractMixin):
#TODO: These tests could be factored into a generic Op-testing base-class
def setUp(self):
self.ops = [ger, ger_destructive]
def clone(self, op):
return Ger(op.destructive)
class TestGer_make_thunk(TestCase):
def setUp(self):
self.rng = numpy.random.RandomState(unittest_tools.fetch_seed())
def given_dtype(self, dtype, M, N):
sA = T.tensor(dtype=dtype, broadcastable=(False, False))
sa = T.tensor(dtype=dtype, broadcastable=())
sx = T.tensor(dtype=dtype, broadcastable=(False,))
sy = T.tensor(dtype=dtype, broadcastable=(False,))
sZ = ger(sA, sa, sx, sy)
node = sZ.owner
storage_map = {sA:[None], sa:[None], sx:[None], sy:[None], sZ:[None]}
thunk = ger.make_thunk(node, storage_map,
compute_map={}, no_recycling=[])
# non-standard for make_thunk to receive node.op != self,
# but works for now.
thunk_d = ger_destructive.make_thunk(node, storage_map,
compute_map={}, no_recycling=[])
def rand(*shape):
return numpy.asarray(1 + self.rng.rand(*shape), dtype=dtype)
storage_map[sA][0] = rand(M, N)
storage_map[sa][0] = rand()
storage_map[sx][0] = rand(M)
storage_map[sy][0] = rand(N)
storage_map_copy = dict([(k,[deepcopy(v[0])]) for k,v in storage_map.items()])
# TODO: do some DebugMode-type verifications here
# if this can be refactored into a Mixin that does the DebugMode
# stuff on just one thunk at a time. Do it in the style of
# TestOpContractMixin?
# - Compare with Elemwise testers
thunk()
assert numpy.all(storage_map[sZ][0] ==
storage_map[sA][0] + storage_map[sa][0] *
numpy.outer(storage_map[sx][0], storage_map[sy][0]))
assert storage_map[sZ][0].dtype == dtype
assert storage_map[sZ][0].shape == (M, N)
thunk_d()
assert numpy.all(storage_map[sZ][0] !=
storage_map[sA][0] + storage_map[sa][0] *
numpy.outer(storage_map[sx][0], storage_map[sy][0]))
assert numpy.all(storage_map[sZ][0] ==
storage_map_copy[sA][0] + storage_map[sa][0] *
numpy.outer(storage_map[sx][0], storage_map[sy][0]))
assert storage_map[sZ][0].dtype == dtype
assert storage_map[sZ][0].shape == (M, N)
def test_f32_0_0(s): return s.given_dtype('float32', 0, 0)
def test_f32_1_0(s): return s.given_dtype('float32', 1, 0)
def test_f32_0_1(s): return s.given_dtype('float32', 0, 1)
def test_f32_1_1(s): return s.given_dtype('float32', 1, 1)
def test_f32_4_4(s): return s.given_dtype('float32', 4, 4)
def test_f64_4_5(s): return s.given_dtype('float64', 4, 5)
def test_c64_7_1(s): return s.given_dtype('complex64', 7, 1)
def test_c128_1_9(s): return s.given_dtype('complex128', 1, 9)
# The following gemv tests were added in March 2011 by Ian Goodfellow # The following gemv tests were added in March 2011 by Ian Goodfellow
# and are based on the gemv tests from scipy # and are based on the gemv tests from scipy
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论