提交 552b1caa authored 作者: James Bergstra's avatar James Bergstra

added non-inplace version of SetSubtensor

上级 2dc20643
......@@ -1140,7 +1140,7 @@ class Subtensor(Op):
helper(idx)
return ret
def __init__(self, idx_list):
@staticmethod
def convert(entry, slice_ok=True):
scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [bscalar, iscalar, lscalar]
......@@ -1156,14 +1156,16 @@ class Subtensor(Op):
a = entry.start
b = entry.stop
c = entry.step
return slice(convert(a, False) if a is not None else None,
convert(b, False) if b is not None else None,
convert(c, False) if c is not None else None)
return slice(Subtensor.convert(a, False) if a is not None else None,
Subtensor.convert(b, False) if b is not None else None,
Subtensor.convert(c, False) if c is not None else None)
elif isinstance(entry, int):
return entry
else:
raise TypeError(Subtensor.e_indextype, entry)
self.idx_list = map(convert, idx_list)
def __init__(self, idx_list):
self.idx_list = map(self.convert, idx_list)
def make_node(self, x, *inputs):
x = as_tensor(x)
......@@ -1269,10 +1271,42 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), S
class SetSubtensor(Subtensor):
"""WRITEME"""
view_map = {}
destroy_map = {0: [0]}
class SetSubtensor(Op):
"""Set just some elements of a larger Tensor.
This is like numpy's
z[i,j,k] = <something>
"""
def __init__(self, idx_list, inplace=False):
self.idx_list = map(Subtensor.convert, idx_list)
self.inplace = inplace
if inplace:
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other) \
and self.idx_list == other.idx_list \
and self.inplace == other.inplace
def __hash__(self):
idx_list = tuple((entry.start, entry.stop, entry.step)
if isinstance(entry, slice)
else entry
for entry in self.idx_list)
return hash(type(self)) ^ hash(idx_list) ^ hash(self.inplace)
def __str__(self):
indices = []
for entry in self.idx_list:
if isinstance(entry, slice):
indices.append(":".join("" if x is None else str(x) for x in [entry.start, entry.stop, entry.step]))
else:
indices.append(str(entry))
return "%s%s{%s}" % ('Inplace' if self.inplace else '',
self.__class__.__name__, ", ".join(indices))
def make_node(self, x, y, *inputs):
x, y = map(as_tensor, [x, y])
......@@ -1318,6 +1352,8 @@ class SetSubtensor(Subtensor):
cdata = tuple(map(convert, self.idx_list))
if len(cdata) == 1:
cdata = cdata[0]
if not self.inplace:
x = x.copy()
x.__setitem__(cdata, y)
out[0] = x
......
......@@ -5,7 +5,7 @@
from .. import gof
from ..gof import opt, InconsistencyError
from ..gof import opt, InconsistencyError, TopoOptimizer
from elemwise import Elemwise, DimShuffle
from .. import scalar
import basic as T
......@@ -273,6 +273,17 @@ def local_subtensor_make_vector(node):
register_canonicalize(local_subtensor_make_vector)
#after priority 50 Destructive inplace operations
#gemm is the first one now, at priority 70
@gof.local_optimizer([None])
def local_inplace_setsubtensor(node):
if isinstance(node.op, T.SetSubtensor) and not node.op.inplace:
new_op = T.SetSubtensor(node.op.idx_list, inplace=True)
new_node = new_op(*node.inputs)
return new_node.outputs
return False
compile.optdb.register('inplace_setsubtensor', TopoOptimizer(local_inplace_setsubtensor), 60, 'fast_run', 'inplace') #DEBUG
##################
# Middleman cuts #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论