提交 32105ec9 authored 作者: James Bergstra's avatar James Bergstra

refactored gemm-related code. added gemm-related optimizations and tests. added…

refactored gemm-related code. added gemm-related optimizations and tests. added specializations for mul
上级 f28d5cb9
...@@ -783,148 +783,6 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -783,148 +783,6 @@ class EquilibriumOptimizer(NavigatorOptimizer):
print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out" print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out"
class _EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self,
local_optimizers,
failure_callback = None,
max_depth = None,
max_use_ratio = None):
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees = False,
failure_callback = failure_callback)
self.local_optimizers = local_optimizers
self.max_depth = max_depth
self.max_use_ratio = max_use_ratio
self.tracks = defaultdict(list)
self.tracks0 = defaultdict(list)
max_depth = 0
for lopt in local_optimizers:
tracks = lopt.tracks()
for track in tracks:
max_depth = max(max_depth, len(track))
if self.max_depth is not None and max_depth > self.max_depth:
raise ValueError('One of the local optimizers exceeds the maximal depth.')
for i, op in enumerate(track):
if i == 0:
self.tracks0[op].append((track, i, lopt))
self.tracks[op].append((track, i, lopt))
def fetch_tracks(self, op):
return self.tracks[op] + self.tracks[None]
def fetch_tracks0(self, op):
return self.tracks0[op] + self.tracks0[None]
def backtrack(self, node, tasks):
candidates = self.fetch_tracks(node.op)
tracks = []
def filter(node, depth):
new_candidates = []
for candidate in candidates:
track, i, lopt = candidate
if i < depth:
pass
elif track[i-depth] in (None, node.op):
if i == depth:
tasks[node].append(lopt)
else:
tracks.append(candidate)
else:
new_candidates.append(candidate)
return new_candidates
depth = 0
nodes = [node]
while candidates:
for node in nodes:
candidates = filter(node, depth)
depth += 1
_nodes = nodes
nodes = reduce(list.__iadd__,
[reduce(list.__iadd__,
[[n for n, i in out.clients if not isinstance(n, str)] for out in node.outputs],
[]) for node in nodes],
[])
candidates = tracks
tracks = []
def apply(self, env):
tasks = defaultdict(list)
if self.max_use_ratio is not None:
max_uses = self.max_use_ratio * len(env.nodes)
runs = defaultdict(int)
else:
runs = None
def importer(node):
#print 'IMPORTING', node
self.backtrack(node, tasks)
def pruner(node):
try:
del tasks[node]
except KeyError:
pass
def chin(node, i, r, new_r):
if new_r.owner and not r.clients:
self.backtrack(new_r.owner, tasks)
# # == NOT IDEAL == #
# for node in env.nodes:
# importer(node)
for node in env.toposort():
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner, chin)
print 'KEYS', map(hash, tasks.keys())
while tasks:
for node in tasks.iterkeys():
todo = tasks.pop(node)
break
for lopt in todo:
if runs is not None and runs[lopt] >= max_uses:
print >>sys.stderr, 'Warning: optimization exceeded its maximal use ratio: %s, %s' % (lopt, max_uses)
continue
success = self.process_node(env, node, lopt)
if success:
if runs is not None: runs[lopt] += 1
break
self.detach_updater(env, u)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, str):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
# if not patterns:
# return patterns
# return self.match(node, depth + 1).intersection(patterns)
# def backtrack(self, node, q):
# for node2, i in node.clients:
# op2 = node2.op
def keep_going(exc, nav, repl_pairs): def keep_going(exc, nav, repl_pairs):
"""WRITEME""" """WRITEME"""
pass pass
...@@ -1002,5 +860,3 @@ class PureThenInplaceOptimizer(Optimizer): ...@@ -1002,5 +860,3 @@ class PureThenInplaceOptimizer(Optimizer):
if 0:
class _EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self,
local_optimizers,
failure_callback = None,
max_depth = None,
max_use_ratio = None):
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees = False,
failure_callback = failure_callback)
self.local_optimizers = local_optimizers
self.max_depth = max_depth
self.max_use_ratio = max_use_ratio
self.tracks = defaultdict(list)
self.tracks0 = defaultdict(list)
max_depth = 0
for lopt in local_optimizers:
tracks = lopt.tracks()
for track in tracks:
max_depth = max(max_depth, len(track))
if self.max_depth is not None and max_depth > self.max_depth:
raise ValueError('One of the local optimizers exceeds the maximal depth.')
for i, op in enumerate(track):
if i == 0:
self.tracks0[op].append((track, i, lopt))
self.tracks[op].append((track, i, lopt))
def fetch_tracks(self, op):
return self.tracks[op] + self.tracks[None]
def fetch_tracks0(self, op):
return self.tracks0[op] + self.tracks0[None]
def backtrack(self, node, tasks):
candidates = self.fetch_tracks(node.op)
tracks = []
def filter(node, depth):
new_candidates = []
for candidate in candidates:
track, i, lopt = candidate
if i < depth:
pass
elif track[i-depth] in (None, node.op):
if i == depth:
tasks[node].append(lopt)
else:
tracks.append(candidate)
else:
new_candidates.append(candidate)
return new_candidates
depth = 0
nodes = [node]
while candidates:
for node in nodes:
candidates = filter(node, depth)
depth += 1
_nodes = nodes
nodes = reduce(list.__iadd__,
[reduce(list.__iadd__,
[[n for n, i in out.clients if not isinstance(n, str)] for out in node.outputs],
[]) for node in nodes],
[])
candidates = tracks
tracks = []
def apply(self, env):
tasks = defaultdict(list)
if self.max_use_ratio is not None:
max_uses = self.max_use_ratio * len(env.nodes)
runs = defaultdict(int)
else:
runs = None
def importer(node):
#print 'IMPORTING', node
self.backtrack(node, tasks)
def pruner(node):
try:
del tasks[node]
except KeyError:
pass
def chin(node, i, r, new_r):
if new_r.owner and not r.clients:
self.backtrack(new_r.owner, tasks)
# # == NOT IDEAL == #
# for node in env.nodes:
# importer(node)
for node in env.toposort():
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner, chin)
print 'KEYS', map(hash, tasks.keys())
while tasks:
for node in tasks.iterkeys():
todo = tasks.pop(node)
break
for lopt in todo:
if runs is not None and runs[lopt] >= max_uses:
print >>sys.stderr, 'Warning: optimization exceeded its maximal use ratio: %s, %s' % (lopt, max_uses)
continue
success = self.process_node(env, node, lopt)
if success:
if runs is not None: runs[lopt] += 1
break
self.detach_updater(env, u)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, str):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
# if not patterns:
# return patterns
# return self.match(node, depth + 1).intersection(patterns)
# def backtrack(self, node, q):
# for node2, i in node.clients:
# op2 = node2.op
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from basic import * from basic import *
import opt import opt
import blas
import raw_random import raw_random
from raw_random import \ from raw_random import \
......
差异被折叠。
差异被折叠。
差异被折叠。
from basic import _scal_elemwise #, _transpose_inplace from .basic import _scal_elemwise #, _transpose_inplace
from .. import scalar as scal from .. import scalar as scal
import elemwise import elemwise
from .. import printing from .. import printing
......
差异被折叠。
...@@ -1356,179 +1356,6 @@ class t_dot(unittest.TestCase): ...@@ -1356,179 +1356,6 @@ class t_dot(unittest.TestCase):
#verify_grad(self, dot, [self.rand(), self.rand(2)]) #verify_grad(self, dot, [self.rand(), self.rand(2)])
#verify_grad(self, dot, [self.rand(), self.rand(2,5)]) #verify_grad(self, dot, [self.rand(), self.rand(2,5)])
class t_gemm(unittest.TestCase):
def setUp(self):
numpy.random.seed(44)
_approx_eq.debug = 0
Gemm.debug = False
@staticmethod
def _gemm(z,a,x,y,b):
assert a.shape == ()
assert b.shape == ()
return b * z + a * numpy.dot(x,y)
@staticmethod
def rand(*args):
return numpy.random.rand(*args)
def cmp(self, z, a, x, y, b):
def cmp_linker(z, a, x, y, b, l):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
z_orig = z.copy()
tz,ta,tx,ty,tb = [as_tensor(p).type() for p in z,a,x,y,b]
f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l))
new_z = f(z,a,x,y,b)
z_after = self._gemm(z_orig, a, x, y, b)
self.failUnless(z is new_z)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1
self.failUnless(_approx_eq(z_after, z))
if a == 0.0 and b == 1.0:
return
else:
self.failIf(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, 'c|py')
cmp_linker(copy(z), a, x, y, b, 'c')
cmp_linker(copy(z), a, x, y, b, 'py')
def test0a(self):
Gemm.debug = True
try:
g = gemm([1.], 1., [1.], [1.], 1.)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test0(self):
try:
self.cmp(1., 0., 1.0, 1.0, 1.0)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test2(self):
try:
self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
except ValueError, e:
self.failUnless(e[0] == Gemm.E_rank)
return
self.fail()
def test4(self):
self.cmp(self.rand(3,4), 1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test5(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test6(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test7(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.0)
def test8(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.6)
def test9(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test10(self):
_approx_eq.debug = 1
self.cmp(self.rand(3,4), -1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test11(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test_destroy_map0(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z, Z, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map1(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map2(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map3(self):
"""test that only first input can be overwritten"""
Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2))
try:
gemm(Z, 1.0, Z, A, 1.0)
except ValueError, e:
if e[0] == Gemm.E_z_uniq:
return
self.fail()
def test_destroy_map4(self):
"""test that dot args can be aliased"""
Z = value(self.rand(2,2))
A = value(self.rand(2,2))
eval_outputs([gemm(Z, 1.0, A, A, 1.0)])
eval_outputs([gemm(Z, 1.0, A, A.T, 1.0)])
def test_transposes(self):
# three square matrices which are not contiguous
A = self.rand(4,5)[:,:4]
B = self.rand(4,5)[:,:4]
C = self.rand(4,5)[:,:4]
def t(z,x,y,a=1.0, b=0.0,l='c|py',dt='float64'):
z,a,x,y,b = [numpy.asarray(p,dtype=dt) for p in z,a,x,y,b]
z_orig = z.copy()
z_after = self._gemm(z, a, x, y, b)
tz,ta,tx,ty,tb = [value(p) for p in z,a,x,y,b]
f = function([tz,ta,tx,ty,tb], gemm(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l))
f(z, a, x, y, b)
self.failUnless(_approx_eq(z_after, z), (z_orig, z_after, z, z_after - z))
f(z.T, a, y.T, x.T, b)
self.failUnless(_approx_eq(z_after, z))
t(C,A,B)
t(C.T, A, B)
t(C, A.T, B, dt='float32')
t(C, A, B.T)
t(C.T, A.T, B)
t(C, A.T, B.T, dt='float32')
t(C.T, A, B.T)
t(C.T, A.T, B.T, dt='float32')
t(C, A[:,:2], B[:2, :])
t(C.T, A[:,:2], B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:2, :])
t(C.T, A[:2,:].T, B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:, :2].T)
t(C.T, A[:2,:].T, B[:, :2].T)
try:
t(C.T, A[:2,:], B[:, :2].T)
except ValueError, e:
if e[0].find('aligned') >= 0:
return
self.fail()
class T_tensorfromscalar(unittest.TestCase): class T_tensorfromscalar(unittest.TestCase):
def test0(self): def test0(self):
s = scal.constant(56) s = scal.constant(56)
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论