提交 4549f839 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

pulled ShuffleRule

上级 071e7535
......@@ -29,59 +29,6 @@ def TensorConstant(*inputs, **kwargs):
### DimShuffle ###
##################
class ShuffleRule(Macro):
"""
ABSTRACT Op - it has no perform and no c_code
Apply ExpandMacros to this node to obtain
an equivalent DimShuffle which can be performed.
"""
level = 1
def __init__(self, rule = None, inplace = False, name = None):
if rule is not None:
self.rule = rule
self.inplace = inplace
if inplace:
self.view_map = {0: [0]}
self.name = name
def make_node(self, input, *models):
pattern = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
ib = input.type.broadcastable
return gof.Apply(self,
(input,) + models,
[Tensor(dtype = input.type.dtype,
broadcastable = [x == 'x' or ib[x] for x in pattern]).make_result()])
def expand(self, node):
input, models = node.inputs[0], node.inputs[1:]
new_order = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
#print new_order, node.outputs[0].type, DimShuffle(input.type.broadcastable, new_order)(input).type, node.outputs[0].type == DimShuffle(input.type.broadcastable, new_order)(input).type
if list(new_order) == range(input.type.ndim) and self.inplace:
return [input]
else:
return [DimShuffle(input.type.broadcastable, new_order, self.inplace)(input)]
def __eq__(self, other):
return type(self) == type(other) and self.rule == other.rule
def __hash__(self, other):
return hash(self.rule)
def __str__(self):
if self.name is not None:
return self.name
else:
return "ShuffleRule{%s}" % self.role
_transpose = ShuffleRule(rule = lambda input: range(len(input)-1, -1, -1),
inplace = True,
name = 'transpose')
lcomplete = ShuffleRule(rule = lambda input, *models: ['x']*(max([0]+map(len,models))-len(input)) + range(len(input)),
inplace = True,
name = 'lcomplete')
rcomplete = ShuffleRule(rule = lambda input, *models: range(len(input)) + ['x']*(max(map(len,models))-len(input)),
inplace = True,
name = 'rcomplete')
class DimShuffle(Op):
"""
Allows to reorder the dimensions of a tensor or insert or remove
......@@ -234,24 +181,6 @@ class DimShuffle(Op):
return DimShuffle(gz.type.broadcastable, grad_order)(gz),
# class LComplete(Op):
# view_map = {0: [0]}
# def make_node(self, x, y):
# x, y = map(as_tensor, (x, y))
# xd, yd = x.type.ndim, y.type.ndim
# if xd > yd:
# raise TypeError("The tensor to left-complete has more dimensions than the model.")
# return gof.Apply(self,
# [x, y],
# [Tensor(dtype = x.type.dtype,
# broadcastable = (True,)*(yd-xd) + x.type.broadcastable).make_result()])
# def perform(self, node, (x, y), (z, )):
# return x.reshape((1, )*(y.ndim - x.ndim) + tuple(x.shape))
# def grad(self, node, (x, ), (gz, )):
# xd, gzd = x.type.ndim, gz.type.ndim
# return DimShuffle(gz.broadcastable, range(gzd-xd, xd))(gz)
################
### Elemwise ###
################
......@@ -313,9 +242,6 @@ class Elemwise(Op):
target_length = max([input.type.ndim for input in inputs])
# if len(inputs) > 1:
# inputs = [lcomplete(input, *inputs) for input in inputs]
args = []
for input in inputs:
length = input.type.ndim
......@@ -326,7 +252,8 @@ class Elemwise(Op):
# TODO: use LComplete instead
args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length), inplace = True)(input))
inputs = args
# # Following conditions should always be true?
# try:
# assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
# except (AssertionError, AttributeError):
......
......@@ -81,9 +81,6 @@ class DimShufflePrinter:
def process(self, r, pstate):
if r.owner is None:
raise TypeError("Can only print DimShuffle.")
elif isinstance(r.owner.op, T.ShuffleRule):
new_r = r.owner.op.expand(r.owner)[0]
return pstate.pprinter.process(new_r, pstate)
elif isinstance(r.owner.op, T.DimShuffle):
ord = r.owner.op.new_order
return self.__p(ord, pstate, r.owner.inputs[0])
......@@ -174,7 +171,6 @@ def make_default_pp():
pp.assign(T.Sum(), FunctionPrinter('sum'))
pp.assign(T.grad, FunctionPrinter('d'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.DimShuffle), DimShufflePrinter())
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.ShuffleRule), DimShufflePrinter())
return pp
pp = make_default_pp()
......
......@@ -18,7 +18,7 @@ from gof.python25 import partial
### set up the external interface
from elemwise import Elemwise, ShuffleRule, DimShuffle, CAReduce, Sum
from elemwise import Elemwise, DimShuffle, CAReduce, Sum
import tensor_random as random
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论