提交 abd9bef4 authored 作者: James Bergstra's avatar James Bergstra

new GEMM: removed dead code, changed definition of _as_scalar

上级 76a6cd53
...@@ -509,14 +509,21 @@ def res_is_a(node, op, maxclients=None): ...@@ -509,14 +509,21 @@ def res_is_a(node, op, maxclients=None):
def _as_scalar(res): def _as_scalar(res):
"""Return None or a TensorVariable whose type is in T.float_scalar_types""" """Return None or a TensorVariable whose type is in T.float_scalar_types"""
if res.owner and isinstance(res.owner.op, T.DimShuffle): if numpy.all(res.type.broadcastable):
return _as_scalar(res.owner.inputs[0]) while res.owner and isinstance(res.owner.op, T.DimShuffle):
elif res.type in T.float_scalar_types: res = res.owner.inputs[0]
return res if res.type.broadcastable: # may still have some number of True's
elif isinstance(res, T.Constant) and res.data.size == 1: rval = res.dimshuffle()
return res.data.flatten()[0] else:
else: rval = res
return None
if rval.type.dtype[:3] in ('int', 'uin'):
rval = cast(rval, theano.config.floatX) #may lose precision !?
#if isinstance(rval, T.Constant):
#rval = rval.data.flatten()[0]
return rval
def _is_real_matrix(res): def _is_real_matrix(res):
return res.type.dtype in ('float32', 'float64') \ return res.type.dtype in ('float32', 'float64') \
...@@ -524,39 +531,6 @@ def _is_real_matrix(res): ...@@ -524,39 +531,6 @@ def _is_real_matrix(res):
and res.type.broadcastable[0] == False \ and res.type.broadcastable[0] == False \
and res.type.broadcastable[1] == False #cope with tuple vs. list and res.type.broadcastable[1] == False #cope with tuple vs. list
def _as_isolated_scalar_times_matrix(res):
"""Returns (scalar_var, matrix_var) on success else None
"""
# isolated means that there is only one client of the result 'res'
if res_is_a(res, T.mul, 1):
if len(res.owner.inputs) == 2:
L, R = res.owner.inputs
sL = _as_scalar(L)
sR = _as_scalar(R)
if (sL is not None) and _is_real_matrix(R):
return (sL, R)
if (sR is not None) and _is_real_matrix(L):
return (sR, L)
else:
scalars = []
matrices = []
for input in res.owner.inputs:
scalar_input = _as_scalar(input)
if scalar_input is not None:
scalars.append(scalar_input)
elif _is_real_matrix(input):
matrices.append(input)
else:
return None
if len(matrices) == 1:
if len(scalars) == 0:
rval = (1.0, matrices[0])
elif len(scalars) == 1:
rval = (scalars[0], matrices[0])
else:
rval = (T.mul(*scalars), matrices[0])
return rval
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip #print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M) #EXPRESSION: (beta * L) + (alpha * M)
...@@ -609,6 +583,10 @@ def _gemm_canonicalize(r, scale, rval, maxclients): ...@@ -609,6 +583,10 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
return -thing return -thing
else: else:
return scale*thing return scale*thing
try:
r.type.broadcastable
except:
return None
if (tuple(r.type.broadcastable) != (False, False) or if (tuple(r.type.broadcastable) != (False, False) or
r.type.dtype not in ('float32', 'float64', 'complex64', 'complex128')): r.type.dtype not in ('float32', 'float64', 'complex64', 'complex128')):
...@@ -748,119 +726,6 @@ def _gemm_from_node2(node): ...@@ -748,119 +726,6 @@ def _gemm_from_node2(node):
rval = _gemm_from_factored_list(lst) rval = _gemm_from_factored_list(lst)
return rval return rval
def inputs_as_scalar_times_matrix(node):
# try to interpret an expression as a sum of scalar * matrix terms plus an 'other' term.
# This function *could* recurse and flatten sub and add hierarchies, but it doesn't.
# Reason being - if we didn't need intermediate results, the canonizer should already done
# that.
# returns three lists: sM_list, sM_orig, other
# - sM_list is a list of pairs: the interpretation of some terms as scalar,matrix products
# - sM_orig is a list of variables: the originals before interpretation into sM_list
# - other is a list of terms that are not float matrices
op = None
sM_list = []
sM_orig = []
other = []
if node.op == T.add or node.op == T.sub:
op = node.op
for input in node.inputs:
tmp = _as_isolated_scalar_times_matrix(input)
if tmp:
sM_list.append(tmp)
sM_orig.append(input)
elif _is_real_matrix(input):
sM_list.append((1.0, input))
sM_orig.append(input)
else:
other.append(input)
assert len(sM_list) == len(sM_orig)
assert len(sM_list) + len(other) == len(node.inputs)
return op, sM_list, sM_orig, other
def _gemm_from_sM_list(node, sM_list, sM_orig, other_inputs):
"""Returns None, or a list to replace node.outputs
"""
if len(sM_list) == 2:
(sL, mL), (sR, mR) = sM_list
gemm_of_sM_list = _beta_L_plus_alpha_M(sL, mL, sR, mR)
if gemm_of_sM_list:
#we turned the two candidates into a gemm
# now we have to add the other_inputs and return the replacement graph
if other_inputs:
return [T.add(*(other_inputs + gemm_of_sM_list))]
else:
return gemm_of_sM_list
else:
# Try every pair in the sM_list, trying to turn it into a gemm operation
for i in xrange(len(sM_list) - 1):
for j in xrange(i+1, len(sM_list)):
assert i != j
sL, mL = sM_list[i]
sR, mR = sM_list[j]
gemm_of_sM_list = _beta_L_plus_alpha_M(sL, mL, sR, mR)
if gemm_of_sM_list:
assert len(gemm_of_sM_list) == 1
inputs_without_ij = [input for k, input in enumerate(sM_orig) if k not in (i,j)]
new_add_inputs = (inputs_without_ij + gemm_of_sM_list + other_inputs)
# this should be True because we've combined a pair of arguments
# into a single GEMM
assert len(new_add_inputs) + 1 == len(node.inputs)
return [T.add(*new_add_inputs)]
def _gemm_from_node(node):
"""
:todo: In many expressions, there are many ways to turn it into a gemm. For example
dot(a,b) + c + d. This function should return all of them, so that if one version of gemm
causes a cycle in the graph, then another application of gemm can be tried.
"""
op, sM_list, sM_orig, other_inputs = inputs_as_scalar_times_matrix(node)
if op == T.sub and len(sM_list)==2:
(sL, mL), (sR,mR) = sM_list
rval = _gemm_from_sM_list([(sL, mL), (-sR,mR)], None, None)
if rval:
return rval
#theano.printing.debugprint(node.outputs[0], depth=6)
if len(sM_orig[1].clients)==1:
# Canonicalize this subgraph
# There is a form of Gemm that escapes the approach above
# g*W - (a * (e*dot(b,c) + d * W + X))
#
# -> gemm(W, -a*e, b, c, g-a*d) - a*X
#
# In this case g=sL W=mL, and a=sR. We must see if mR is a add() or a sub, in which
# one of the arguments is a scaled version of W a.k.a mL
Rop, RsM_list, RsM_orig, Rother_inputs = inputs_as_scalar_times_matrix(mR.owner)
RsM_list_that_is_mL = [s for (s,m) in RsM_list if m is mL]
if RsM_list_that_is_mL and Rop == T.add:
pass
#g= sL - T.mul(sR,*RsM_list_that_is_mL)
#rval = _gemm_from_sM_list(
#[(g,mL)] + []]
#]
#)
#if Rop == T.add:
#rval = _beta_L_plus_alpha_M(
#L=mL,
#alpha=sR,
#R=T.)
return rval
if op == T.add:
return _gemm_from_sM_list(sM_list, sM_orig, other_inputs)
class GemmOptimizer(Optimizer): class GemmOptimizer(Optimizer):
"""Graph optimizer for inserting Gemm operations""" """Graph optimizer for inserting Gemm operations"""
def __init__(self): def __init__(self):
...@@ -1136,3 +1001,5 @@ def local_dot22_to_dot22scalar(node): ...@@ -1136,3 +1001,5 @@ def local_dot22_to_dot22scalar(node):
blas_optdb.register('local_dot22_to_dot22scalar', blas_optdb.register('local_dot22_to_dot22scalar',
EquilibriumOptimizer([local_dot22_to_dot22scalar ], max_use_ratio=5), EquilibriumOptimizer([local_dot22_to_dot22scalar ], max_use_ratio=5),
11, 'fast_run') 11, 'fast_run')
...@@ -217,15 +217,17 @@ class t_as_scalar(TestCase): ...@@ -217,15 +217,17 @@ class t_as_scalar(TestCase):
"""Test that it works on scalar constants""" """Test that it works on scalar constants"""
a = T.constant(2.5) a = T.constant(2.5)
b = T.constant(numpy.asarray([[[0.5]]])) b = T.constant(numpy.asarray([[[0.5]]]))
b2 = b.dimshuffle()
assert b2.ndim == 0
d_a = T.DimShuffle([], [])(a) d_a = T.DimShuffle([], [])(a)
d_b = T.DimShuffle([True, True, True], [0,2,1])(b) d_b = T.DimShuffle([True, True, True], [0,2,1])(b)
d_a2 = T.DimShuffle([], ['x', 'x', 'x'])(a) d_a2 = T.DimShuffle([], ['x', 'x', 'x'])(a)
self.failUnless(numpy.all(_as_scalar(a) == a)) self.failUnless(_as_scalar(a) == a)
self.failUnless(numpy.all(_as_scalar(b) == b.data), (b, _as_scalar(b))) self.failUnless(_as_scalar(b) != b)
self.failUnless(numpy.all(_as_scalar(d_a) == a)) self.failUnless(_as_scalar(d_a) != d_a)
self.failUnless(numpy.all(_as_scalar(d_b) == b.data)) self.failUnless(_as_scalar(d_b) != d_b)
self.failUnless(numpy.all(_as_scalar(d_a2) == a)) self.failUnless(_as_scalar(d_a2) != d_a2)
def test1(self): def test1(self):
"""Test that it fails on nonscalar constants""" """Test that it fails on nonscalar constants"""
...@@ -432,6 +434,7 @@ def test_gemm_opt_wishlist(): ...@@ -432,6 +434,7 @@ def test_gemm_opt_wishlist():
#with >2 additions of the same T.dot(X,Y term #with >2 additions of the same T.dot(X,Y term
just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)])
just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
def test_gemm_with_vector(): def test_gemm_with_vector():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论