提交 552b1caa authored 作者: James Bergstra's avatar James Bergstra

added non-inplace version of SetSubtensor

上级 2dc20643
...@@ -1140,7 +1140,7 @@ class Subtensor(Op): ...@@ -1140,7 +1140,7 @@ class Subtensor(Op):
helper(idx) helper(idx)
return ret return ret
def __init__(self, idx_list): @staticmethod
def convert(entry, slice_ok=True): def convert(entry, slice_ok=True):
scal_types = [scal.int64, scal.int32, scal.int16, scal.int8] scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [bscalar, iscalar, lscalar] tensor_types = [bscalar, iscalar, lscalar]
...@@ -1156,14 +1156,16 @@ class Subtensor(Op): ...@@ -1156,14 +1156,16 @@ class Subtensor(Op):
a = entry.start a = entry.start
b = entry.stop b = entry.stop
c = entry.step c = entry.step
return slice(convert(a, False) if a is not None else None, return slice(Subtensor.convert(a, False) if a is not None else None,
convert(b, False) if b is not None else None, Subtensor.convert(b, False) if b is not None else None,
convert(c, False) if c is not None else None) Subtensor.convert(c, False) if c is not None else None)
elif isinstance(entry, int): elif isinstance(entry, int):
return entry return entry
else: else:
raise TypeError(Subtensor.e_indextype, entry) raise TypeError(Subtensor.e_indextype, entry)
self.idx_list = map(convert, idx_list)
def __init__(self, idx_list):
self.idx_list = map(self.convert, idx_list)
def make_node(self, x, *inputs): def make_node(self, x, *inputs):
x = as_tensor(x) x = as_tensor(x)
...@@ -1269,10 +1271,42 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), S ...@@ -1269,10 +1271,42 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), S
class SetSubtensor(Subtensor): class SetSubtensor(Op):
"""WRITEME""" """Set just some elements of a larger Tensor.
view_map = {}
destroy_map = {0: [0]} This is like numpy's
z[i,j,k] = <something>
"""
def __init__(self, idx_list, inplace=False):
self.idx_list = map(Subtensor.convert, idx_list)
self.inplace = inplace
if inplace:
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other) \
and self.idx_list == other.idx_list \
and self.inplace == other.inplace
def __hash__(self):
idx_list = tuple((entry.start, entry.stop, entry.step)
if isinstance(entry, slice)
else entry
for entry in self.idx_list)
return hash(type(self)) ^ hash(idx_list) ^ hash(self.inplace)
def __str__(self):
indices = []
for entry in self.idx_list:
if isinstance(entry, slice):
indices.append(":".join("" if x is None else str(x) for x in [entry.start, entry.stop, entry.step]))
else:
indices.append(str(entry))
return "%s%s{%s}" % ('Inplace' if self.inplace else '',
self.__class__.__name__, ", ".join(indices))
def make_node(self, x, y, *inputs): def make_node(self, x, y, *inputs):
x, y = map(as_tensor, [x, y]) x, y = map(as_tensor, [x, y])
...@@ -1318,6 +1352,8 @@ class SetSubtensor(Subtensor): ...@@ -1318,6 +1352,8 @@ class SetSubtensor(Subtensor):
cdata = tuple(map(convert, self.idx_list)) cdata = tuple(map(convert, self.idx_list))
if len(cdata) == 1: if len(cdata) == 1:
cdata = cdata[0] cdata = cdata[0]
if not self.inplace:
x = x.copy()
x.__setitem__(cdata, y) x.__setitem__(cdata, y)
out[0] = x out[0] = x
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from .. import gof from .. import gof
from ..gof import opt, InconsistencyError from ..gof import opt, InconsistencyError, TopoOptimizer
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
from .. import scalar from .. import scalar
import basic as T import basic as T
...@@ -273,6 +273,17 @@ def local_subtensor_make_vector(node): ...@@ -273,6 +273,17 @@ def local_subtensor_make_vector(node):
register_canonicalize(local_subtensor_make_vector) register_canonicalize(local_subtensor_make_vector)
#after priority 50 Destructive inplace operations
#gemm is the first one now, at priority 70
@gof.local_optimizer([None])
def local_inplace_setsubtensor(node):
if isinstance(node.op, T.SetSubtensor) and not node.op.inplace:
new_op = T.SetSubtensor(node.op.idx_list, inplace=True)
new_node = new_op(*node.inputs)
return new_node.outputs
return False
compile.optdb.register('inplace_setsubtensor', TopoOptimizer(local_inplace_setsubtensor), 60, 'fast_run', 'inplace') #DEBUG
################## ##################
# Middleman cuts # # Middleman cuts #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论