提交 61d10d72 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added ShuffleRule

上级 d2bc9167
......@@ -2,7 +2,7 @@
import elemwise_cgen as cgen
import numpy
from gof import Op, Apply
from gof import Op, Macro, Apply
import scalar
from scalar import Scalar
import gof
......@@ -29,6 +29,50 @@ def TensorConstant(*inputs, **kwargs):
### DimShuffle ###
##################
## TODO: rule-based version of DimShuffle
## would allow for Transpose, LComplete, RComplete, etc.
## Can be optimized into DimShuffle later on.
class ShuffleRule(Macro):
"""
ABSTRACT Op - it has no perform and no c_code
Apply ExpandMacros to this node to obtain
an equivalent DimShuffle which can be performed.
"""
def __init__(self, rule = None, name = None):
if rule is not None:
self.rule = rule
self.name = name
def make_node(self, input, *models):
pattern = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
return gof.Apply(self,
(input,) + models,
[Tensor(dtype = input.type.dtype,
broadcastable = [x == 'x' for x in pattern]).make_result()])
def expand(self, r):
input, models = r.owner.inputs[0], r.owner.inputs[1:]
new_order = self.rule(input.type.broadcastable, *(model.type.broadcastable for model in models))
return DimShuffle(input.type.broadcastable, new_order)(input)
def __eq__(self, other):
return type(self) == type(other) and self.rule == other.rule
def __hash__(self, other):
return hash(self.rule)
def __str__(self):
if self.name is not None:
return self.name
else:
return "ShuffleRule{%s}" % self.role
_transpose = ShuffleRule(rule = lambda input: range(len(input)-1, -1, -1),
name = 'transpose')
lcomplete = ShuffleRule(rule = lambda input, *models: ['x']*(max([0]+map(len,models))-len(input)) + range(len(input)),
name = 'lcomplete')
rcomplete = ShuffleRule(rule = lambda input, *models: range(len(input)) + ['x']*(max(map(len,models))-len(input)),
name = 'rcomplete')
class DimShuffle(Op):
"""
......@@ -182,6 +226,23 @@ class DimShuffle(Op):
return DimShuffle(gz.type.broadcastable, grad_order)(gz),
# class LComplete(Op):
# view_map = {0: [0]}
# def make_node(self, x, y):
# x, y = map(as_tensor, (x, y))
# xd, yd = x.type.ndim, y.type.ndim
# if xd > yd:
# raise TypeError("The tensor to left-complete has more dimensions than the model.")
# return gof.Apply(self,
# [x, y],
# [Tensor(dtype = x.type.dtype,
# broadcastable = (True,)*(yd-xd) + x.type.broadcastable).make_result()])
# def perform(self, node, (x, y), (z, )):
# return x.reshape((1, )*(y.ndim - x.ndim) + tuple(x.shape))
# def grad(self, node, (x, ), (gz, )):
# xd, gzd = x.type.ndim, gz.type.ndim
# return DimShuffle(gz.broadcastable, range(gzd-xd, xd))(gz)
################
### Elemwise ###
......@@ -243,20 +304,23 @@ class Elemwise(Op):
shadow = self.scalar_op.make_node(*[Scalar(dtype = t.type.dtype)() for t in inputs])
target_length = max([input.type.ndim for input in inputs])
args = []
for input in inputs:
length = input.type.ndim
difference = target_length - length
if not difference:
args.append(input)
else:
args.append(DimShuffle(range(length), ['x']*difference + range(length))(input))
inputs = args
# try:
# assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
# except (AssertionError, AttributeError):
# raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", inputs)
if len(inputs) > 1:
inputs = [lcomplete(input, *inputs) for input in inputs]
# args = []
# for input in inputs:
# length = input.type.ndim
# difference = target_length - length
# if not difference:
# args.append(input)
# else:
# # TODO: use LComplete instead
# args.append(DimShuffle(input.type.broadcastable, ['x']*difference + range(length))(input))
# inputs = args
try:
assert len(set([len(input.type.broadcastable) for input in inputs])) == 1
except (AssertionError, AttributeError):
raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", inputs)
out_broadcastables = [[all(bcast) for bcast in zip(*[input.type.broadcastable for input in inputs])]] * shadow.nout
inplace_pattern = self.inplace_pattern
......@@ -508,7 +572,13 @@ class CAReduce(Op):
output = Tensor(dtype = input.type.dtype,
broadcastable = [x for i, x in enumerate(input.type.broadcastable) if i not in axis])()
return Apply(self, [input], [output])
def __eq__(self, other):
return type(self) == type(other) and self.scalar_op == other.scalar_op and self.axis == other.axis
def __hash__(self):
return hash(self.scalar_op) ^ hash(self.axis)
def __str__(self):
if self.axis is not None:
return "Reduce{%s}{%s}" % (self.scalar_op, ", ".join(str(x) for x in self.axis))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论