Commit 61d10d72, authored by Olivier Breuleux

added ShuffleRule

Parent: d2bc9167
@@ -2,7 +2,7 @@
 import elemwise_cgen as cgen
 import numpy
-from gof import Op, Apply
+from gof import Op, Macro, Apply
 import scalar
 from scalar import Scalar
 import gof
@@ -29,6 +29,50 @@ def TensorConstant(*inputs, **kwargs):
 ### DimShuffle ###
 ##################
+## TODO: rule-based version of DimShuffle
+## would allow for Transpose, LComplete, RComplete, etc.
+## Can be optimized into DimShuffle later on.
+class ShuffleRule(Macro):
+    """
+    ABSTRACT Op - it has no perform and no c_code.
+    Apply ExpandMacros to this node to obtain
+    an equivalent DimShuffle which can be performed.
+    """
+    def __init__(self, rule = None, name = None):
+        if rule is not None:
+            self.rule = rule
+        self.name = name
+
+    def make_node(self, input, *models):
+        pattern = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
+        return gof.Apply(self,
+                         (input,) + models,
+                         [Tensor(dtype = input.type.dtype,
+                                 broadcastable = [x == 'x' for x in pattern]).make_result()])
+
+    def expand(self, r):
+        input, models = r.owner.inputs[0], r.owner.inputs[1:]
+        new_order = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
+        return DimShuffle(input.type.broadcastable, new_order)(input)
+
+    def __eq__(self, other):
+        return type(self) == type(other) and self.rule == other.rule
+
+    def __hash__(self):
+        return hash(self.rule)
+
+    def __str__(self):
+        if self.name is not None:
+            return self.name
+        else:
+            return "ShuffleRule{%s}" % self.rule
+
+_transpose = ShuffleRule(rule = lambda input: range(len(input)-1, -1, -1),
+                         name = 'transpose')
+lcomplete = ShuffleRule(rule = lambda input, *models: ['x']*(max([0]+map(len,models))-len(input)) + range(len(input)),
+                        name = 'lcomplete')
+rcomplete = ShuffleRule(rule = lambda input, *models: range(len(input)) + ['x']*(max(map(len,models))-len(input)),
+                        name = 'rcomplete')
 class DimShuffle(Op):
     """
@@ -182,6 +226,23 @@ class DimShuffle(Op):
         return DimShuffle(gz.type.broadcastable, grad_order)(gz),
+# class LComplete(Op):
+#     view_map = {0: [0]}
+#     def make_node(self, x, y):
+#         x, y = map(as_tensor, (x, y))
+#         xd, yd = x.type.ndim, y.type.ndim
+#         if xd > yd:
+#             raise TypeError("The tensor to left-complete has more dimensions than the model.")
+#         return gof.Apply(self,
+#                          [x, y],
+#                          [Tensor(dtype = x.type.dtype,
+#                                  broadcastable = (True,)*(yd-xd) + x.type.broadcastable).make_result()])
+#     def perform(self, node, (x, y), (z, )):
+#         z[0] = x.reshape((1, )*(y.ndim - x.ndim) + tuple(x.shape))
+#     def grad(self, node, (x, y), (gz, )):
+#         xd, gzd = x.type.ndim, gz.type.ndim
+#         return DimShuffle(gz.type.broadcastable, range(gzd-xd, gzd))(gz), None
 ################
 ### Elemwise ###
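
The commented-out LComplete draft leans on the fact that prepending length-1 axes with reshape yields exactly the left-completed view. A minimal numpy sketch of what its perform computes (the arrays here are made-up stand-ins):

import numpy

x = numpy.arange(6).reshape(2, 3)    # ndim 2, the tensor to left-complete
y = numpy.zeros((4, 5, 2, 3))        # ndim 4, the model
z = x.reshape((1,)*(y.ndim - x.ndim) + tuple(x.shape))
print z.shape                        # (1, 1, 2, 3)
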
@@ -243,20 +304,23 @@ class Elemwise(Op):
         shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs])
         target_length = max([input.type.ndim for input in inputs])
-        args = []
-        for input in inputs:
-            length = input.type.ndim
-            difference = target_length - length
-            if not difference:
-                args.append(input)
-            else:
-                args.append(DimShuffle(range(length), ['x']*difference + range(length))(input))
-        inputs = args
-
-#         try:
-#             assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
-#         except (AssertionError, AttributeError):
-#             raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", inputs)
+        if len(inputs) > 1:
+            inputs = [lcomplete(input, *inputs) for input in inputs]
+#         args = []
+#         for input in inputs:
+#             length = input.type.ndim
+#             difference = target_length - length
+#             if not difference:
+#                 args.append(input)
+#             else:
+#                 # TODO: use LComplete instead
+#                 args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length))(input))
+#         inputs = args
+
+        try:
+            assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
+        except (AssertionError, AttributeError):
+            raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", inputs)
         out_broadcastables = [[all(bcast) for bcast in zip(*[input.type.broadcastable for input in inputs])]] * shadow.nout
         inplace_pattern = self.inplace_pattern
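
The out_broadcastables line computes the result's pattern dimension by dimension: an output axis is broadcastable only when the corresponding axis is broadcastable in every input. A standalone illustration (plain tuples in place of Tensor types):

b1 = (True, False, True)
b2 = (True, True, False)
print [all(bcast) for bcast in zip(b1, b2)]    # [True, False, False]
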
@@ -509,6 +573,12 @@ class CAReduce(Op):
                                 broadcastable = [x for i, x in enumerate(input.type.broadcastable) if i not in axis])()
         return Apply(self, [input], [output])
+    def __eq__(self, other):
+        return type(self) == type(other) and self.scalar_op == other.scalar_op and self.axis == other.axis
+
+    def __hash__(self):
+        return hash(self.scalar_op) ^ hash(self.axis)
+
     def __str__(self):
         if self.axis is not None:
             return "Reduce{%s}{%s}" % (self.scalar_op, ", ".join(str(x) for x in self.axis))
...