提交 88f7858f 作者: Olivier Breuleux

more new elemwise stuff

上级 a32fcd55
......@@ -87,12 +87,13 @@ def broadcasting_cgen(op):
class DimShuffle(Op, Viewer):
def __init__(self, input, new_order):
def __init__(self, input, new_order, inplace = True):
ib = input.broadcastable
ob = []
for value in new_order:
if value == 'x':
self.has_x = True
ob.append(1)
else:
ob.append(ib[value])
......@@ -104,8 +105,34 @@ class DimShuffle(Op, Viewer):
self.inputs = input,
self.outputs = output,
self.inplace = inplace
self.numorder = [x for x in new_order if type(x) == int]
self.is_transposition = sorted(new_order) == range(length(ib))
self.dup_dims = len(set(self.numorder)) != len(self.numorder)
self.all_dims = len(set(self.numorder)) == len(ib)
if self.dup_dims or not self.all_dims:
raise NotImplementedError("You must provide a permutation of *all* the input dimensions with *no duplicates*.")
def view_map(self):
    """
    Report which inputs the output is a view of.

    Returns a dict mapping the single output to a list holding the single
    input when ``self.inplace`` is true, and an empty dict otherwise.
    """
    # The unconditional return that used to come first made the
    # ``inplace`` check below unreachable; only the conditional form is kept.
    if self.inplace:
        return {self.outputs[0]: [self.inputs[0]]}
    return {}
def perform(self):
    """
    Compute the shuffled array and store it in ``self.outputs[0].data``.

    The input data is transposed according to ``self.numorder`` (the integer
    entries of ``self.new_order``), then reshaped so that a length-1 axis
    sits at every position where ``self.new_order`` holds ``'x'``.  When
    ``self.inplace`` is false the result is copied before being stored.
    """
    res = self.inputs[0].data.transpose(self.numorder)
    # Axes of the transposed array are consumed front to back so they keep
    # the order requested in new_order (popping from the end reversed them).
    remaining = iter(res.shape)
    new_shape = []
    for entry in self.new_order:
        if entry == 'x':
            new_shape.append(1)
        else:
            new_shape.append(next(remaining))
    res = res.reshape(new_shape)
    if not self.inplace:
        res = numpy.copy(res)
    self.outputs[0].data = res
def __str__(self):
    """Render the op as ``ClassName(<input>, <new_order>)``."""
    class_name = self.__class__.__name__
    input_text = str(self.inputs[0])
    return "%s(%s, %s)" % (class_name, input_text, self.new_order)
......@@ -138,6 +165,7 @@ class Broadcast(Op, Destroyer):
self.inplace_pattern = inplace_pattern
self.scalar_opclass = scalar_opclass
self.shadow = scalar_opclass([Scalar(dtype = t.dtype) for t in self.inputs])
self.ufunc = numpy.frompyfunc(scalar_opclass.impl, scalar_opclass.nin, scalar_opclass.nout)
def id(self):
    """Return the tuple (class, scalar op class, inplace pattern) for this op."""
    op_class = self.__class__
    return (op_class, self.scalar_opclass, self.inplace_pattern)
......@@ -170,10 +198,32 @@ class Broadcast(Op, Destroyer):
else:
ret.append(r)
return ret
def broadcast2(op):
def perform(self):
    """
    Apply ``self.ufunc`` to the inputs' data and store the results.

    A buffer is chosen for every output:

    * if the output's index is a key of ``self.inplace_pattern``, the data
      of the input the pattern designates is used and overwritten;
    * otherwise the output's current data is reused when it already has the
      target shape, and a new array of ``output.dtype`` is allocated if not.

    The buffers are passed to ``self.ufunc`` after the input arrays, and
    each one is then stored in the matching ``output.data``.
    """
    # An empty or missing pattern means no output works in place.
    pattern = self.inplace_pattern or {}
    # NOTE(review): the target shape is taken from the first input only, as
    # in the original code -- confirm this is intended when inputs broadcast.
    shape = self.inputs[0].data.shape
    output_storage = []
    for i, output in enumerate(self.outputs):
        if i in pattern:
            odat = self.inputs[pattern[i]].data
        else:
            odat = output.data
            # ndarray.resize() raises when the array is still referenced
            # (output.data always is) and its memory must change, so a
            # wrongly shaped buffer is replaced instead of resized.
            if odat is None or odat.shape != shape:
                odat = numpy.empty(shape, dtype = output.dtype)
        output_storage.append(odat)
    self.ufunc(*([inp.data for inp in self.inputs] + output_storage))
    # Without this the computed buffers were dropped when perform returned.
    for output, odat in zip(self.outputs, output_storage):
        output.data = odat
def broadcast(op):
def instantiate(*inputs):
target_length = max([len(input.broadcastable) for input in inputs])
args = []
......@@ -186,11 +236,48 @@ def broadcast2(op):
return op(*args)
class FoldX(Op):
    """Stub op: the constructor accepts its arguments and stores nothing."""

    def __init__(self, scalar_opclass, inputs, to_fold):
        # No state is set up here yet.
        return None
class CAReduce(Op):
    """
    Reduction of array inputs by a commutative, associative scalar op.

    Documented signature:
      CAReduce(scalar_op, inputs, dimensions_to_reduce = None, init = None, shortcut = False)

    NOTE(review): the constructor defined below takes neither ``init`` nor
    ``shortcut`` and only checks the scalar op's arity; everything else in
    this description is the documented design, not behavior visible here.

    The number of inputs must be the difference between the number of
    outputs of scalar_op and its number of inputs. CAReduce keeps scalar
    states (accumulators), in proportion to the number of outputs of
    scalar_op, and updates them iteratively:
      for x, y, ... in input0, input1, ...
        scalar_state <- scalar_op(scalar_state, x, y, ...)

    Initial states: ``init`` when provided (the values must be scalars);
    otherwise, if there are as many states as inputs, a sample from each
    input is used; otherwise an error is raised.

    Shortcut: when ``shortcut`` is True and the scalar op has a 'tbd'
    field, iteration tries to stop as soon as the value given by that field
    is met and returns it immediately (eg multiply/and return 0 at the
    first 0, 'or' returns 1 at the first 1).

    Ordering: to optimize memory usage patterns, no guarantee is made about
    the order in which dimensions and elements are visited. For consistent
    results the scalar operation must therefore be both commutative and
    associative (eg add, multiply, binary or/and/xor - but not subtract,
    divide or power).
    """

    def __init__(self, scalar_opclass, inputs, dimensions_to_reduce = None):
        # Only scalar ops with two inputs and one output are accepted.
        arity = (scalar_opclass.nin, scalar_opclass.nout)
        if arity != (2, 1):
            raise NotImplementedError("CAReduce only supports binary functions with a single output.")
def reduce(op, dimensions_to_reduce):
    """
    Build a factory that reduces its inputs with the scalar op class ``op``.

    ``op`` is accepted when its ``commutative`` and ``associative``
    attributes are both true; an attribute that is missing counts as true.
    The returned function takes the inputs as positional arguments and
    returns ``CAReduce(op, inputs, dimensions_to_reduce)``.

    Raises NotImplementedError when ``op`` is not commutative and
    associative.
    """
    if not (getattr(op, 'commutative', True) and getattr(op, 'associative', True)):
        raise NotImplementedError("The scalar op class to reduce must be commutative and associative.")
    # Bound here, as in the original, so the class is looked up when the
    # factory is built rather than when it is called.
    reducer = CAReduce

    def instantiate(*inputs):
        return reducer(op, inputs, dimensions_to_reduce)

    return instantiate
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论