added theano.opt.DimShuffleLifter, theano.opt.find_cliques

上级 4c7455c9
import unittest
import gof
from opt import *
import tensor
from tensor import Tensor
from gof import Env
from elemwise import DimShuffle
import numpy
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
x = Tensor(broadcastable = xbc, dtype = 'float64', name = 'x')
y = Tensor(broadcastable = ybc, dtype = 'float64', name = 'y')
z = Tensor(broadcastable = zbc, dtype = 'float64', name = 'z')
return x, y, z
ds = gof.op.constructor(DimShuffle)
class _test_inplace_opt(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e = x + y + z
g = Env([x, y], [e])
assert str(g) == "[Broadcast{Add}(Broadcast{Add}(x, y), z)]"
inplace_optimizer.optimize(g)
assert str(g) == "[Broadcast{Add}{0: 0}(Broadcast{Add}{0: 0}(x, y), z)]"
def test_multiple_uses(self):
x, y, z = inputs()
e0 = x + y
e1 = x * y
g = Env([x, y], [e0, e1])
assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}(x, y)]"
inplace_optimizer.optimize(g)
assert str(g) == "[Broadcast{Add}{0: 0}(x, y), Broadcast{Mul}(x, y)]" \
or str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]"
def test_user_inplace(self):
x, y, z = inputs()
e0 = x + y
e1 = tensor.mul_inplace(x, y)
g = Env([x, y], [e0, e1])
assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]"
inplace_optimizer.optimize(g)
assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]"
class _test_dimshuffle_lift(unittest.TestCase):
def test_double_transpose(self):
x, y, z = inputs()
e = ds(ds(x, (1, 0)), (1, 0))
g = Env([x], [e])
assert str(g) == "[DimShuffle{10}(DimShuffle{10}(x))]"
lift_dimshuffle.optimize(g)
assert str(g) == "[x]"
def test_merge2(self):
x, y, z = inputs()
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{20x1}(DimShuffle{1x0}(x))]", str(g))
lift_dimshuffle.optimize(g)
self.failUnless(str(g) == "[DimShuffle{01xx}(x)]", str(g))
def test_elim3(self):
x, y, z = inputs()
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{10}(DimShuffle{20x1}(DimShuffle{0x1}(x)))]", str(g))
lift_dimshuffle.optimize(g)
self.failUnless(str(g) == "[x]", str(g))
def test_lift(self):
x, y, z = inputs([0]*1, [0]*2, [0]*3)
e = x + y + z
g = Env([x, y, z], [e])
self.failUnless(str(g) == "[Broadcast{Add}(DimShuffle{x01}(Broadcast{Add}(DimShuffle{x0}(x), y)), z)]", str(g))
lift_dimshuffle.optimize(g)
self.failUnless(str(g) == "[Broadcast{Add}(Broadcast{Add}(DimShuffle{xx0}(x), DimShuffle{x01}(y)), z)]", str(g))
class _test_cliques(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
m = y * z
d = tensor.dot(x, m)
d.name = 'd'
e = x + y + d
g = Env([x, y, z], [e])
cliques = find_cliques(g)
assert len(cliques) == 2
(i1, o1), (i2, o2) = cliques
assert str(Env(i1, [o1])) == "[Broadcast{Add}(Broadcast{Add}(x, y), d)]"
assert str(Env(i2, [o2])) == "[Broadcast{Mul}(y, z)]"
# print g
# for i, o in find_cliques(g):
# print "-->", Env(i, [o])
def test_broadcasting(self):
x, y, z = inputs([0]*1, [0]*2, [0]*3)
e = x + y + z
g = Env([x, y, z], [e])
lift_dimshuffle.optimize(g)
assert len(find_cliques(g, through_broadcast = True)) == 1
assert len(find_cliques(g, through_broadcast = False)) == 2
# print g
# for i, o in find_cliques(g, True):
# print "-->", Env(i, [o])
if __name__ == '__main__':
unittest.main()
from gof import opt from gof import opt
from elemwise import Broadcast from elemwise import Broadcast, DimShuffle
from gof.python25 import any, all
class InplaceOptimizer(opt.OpSpecificOptimizer): class InplaceOptimizer(opt.OpSpecificOptimizer):
...@@ -26,14 +27,122 @@ class InplaceOptimizer(opt.OpSpecificOptimizer): ...@@ -26,14 +27,122 @@ class InplaceOptimizer(opt.OpSpecificOptimizer):
inplace_optimizer = InplaceOptimizer() inplace_optimizer = InplaceOptimizer()
class DimShuffleLifter(opt.Optimizer):
"""
"Lifts" DimShuffle through Broadcast operations and merges
consecutive DimShuffles. Basically, applies the following
transformations on the whole graph:
DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
DimShuffle(DimShuffle(x)) => DimShuffle(x)
After this transform, clusters of Broadcast operations are
void of DimShuffle operations.
"""
def apply(self, env):
seen = set()
def merge(ord1, ord2):
return [x == 'x' and 'x' or ord1[x] for x in ord2]
def lift(r):
if r in seen:
return
seen.add(r)
op = r.owner
if op is None \
or op in env.inputs \
or op in env.orphans():
return
if isinstance(op, DimShuffle):
in_op = op.inputs[0].owner
if isinstance(in_op, DimShuffle):
new_order = [x == 'x' and 'x' or in_op.new_order[x] for x in op.new_order]
if new_order == range(len(new_order)):
repl = in_op.inputs[0]
else:
repl = DimShuffle(in_op.inputs[0], new_order).out
env.replace(r, repl)
lift(repl)
return
elif isinstance(in_op, Broadcast):
repl = Broadcast(in_op.scalar_opclass,
[DimShuffle(input, op.new_order).out for input in in_op.inputs],
in_op.inplace_pattern).out
env.replace(r, repl)
r = repl
op = r.owner
for next_r in op.inputs:
lift(next_r)
for output in env.outputs:
lift(output)
lift_dimshuffle = DimShuffleLifter()
def find_cliques(env, through_broadcast = False):
def seek_from(r):
op = r.owner
if r in env.inputs \
or r in env.orphans() \
or op is None \
or not isinstance(op, Broadcast):
return None
ret = set()
if not through_broadcast:
if any(any(bc) and not all(bc)
for bc in zip(*[input.broadcastable for input in op.inputs])):
ret.update(op.inputs)
return ret
for input in op.inputs:
res = seek_from(input)
if res is None:
ret.add(input)
else:
ret.update(res)
return ret
cliques = []
def find_cliques_helper(r):
if r in env.inputs or r in env.orphans():
return
clique_inputs = seek_from(r)
if clique_inputs is None:
op = r.owner
if op is not None:
for input in op.inputs:
find_cliques_helper(input)
else:
cliques.append((clique_inputs, r))
for input in clique_inputs:
find_cliques_helper(input)
for output in env.outputs:
find_cliques_helper(output)
# todo: merge the cliques if possible
return cliques
# class ElemwisePatternOptimizer(opt.Optimizer): # class ElemwisePatternOptimizer(opt.Optimizer):
# def __init__(self, scalar_opt): # def __init__(self, scalar_opt):
# self. # self.
# def find_elemwise_cliques(env, cross_broadcast = False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论