提交 552b1caa authored 作者: James Bergstra's avatar James Bergstra

added non-inplace version of SetSubtensor

上级 2dc20643
...@@ -1140,30 +1140,32 @@ class Subtensor(Op): ...@@ -1140,30 +1140,32 @@ class Subtensor(Op):
helper(idx) helper(idx)
return ret return ret
@staticmethod
def convert(entry, slice_ok=True):
scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [bscalar, iscalar, lscalar]
if isinstance(entry, gof.Result) and entry.type in scal_types:
return entry.type
elif isinstance(entry, gof.Type) and entry in scal_types:
return entry
if isinstance(entry, gof.Result) and entry.type in tensor_types:
return scal.Scalar(entry.type.dtype)
elif isinstance(entry, gof.Type) and entry in tensor_types:
return scal.Scalar(entry.dtype)
elif slice_ok and isinstance(entry, slice):
a = entry.start
b = entry.stop
c = entry.step
return slice(Subtensor.convert(a, False) if a is not None else None,
Subtensor.convert(b, False) if b is not None else None,
Subtensor.convert(c, False) if c is not None else None)
elif isinstance(entry, int):
return entry
else:
raise TypeError(Subtensor.e_indextype, entry)
def __init__(self, idx_list): def __init__(self, idx_list):
def convert(entry, slice_ok=True): self.idx_list = map(self.convert, idx_list)
scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [bscalar, iscalar, lscalar]
if isinstance(entry, gof.Result) and entry.type in scal_types:
return entry.type
elif isinstance(entry, gof.Type) and entry in scal_types:
return entry
if isinstance(entry, gof.Result) and entry.type in tensor_types:
return scal.Scalar(entry.type.dtype)
elif isinstance(entry, gof.Type) and entry in tensor_types:
return scal.Scalar(entry.dtype)
elif slice_ok and isinstance(entry, slice):
a = entry.start
b = entry.stop
c = entry.step
return slice(convert(a, False) if a is not None else None,
convert(b, False) if b is not None else None,
convert(c, False) if c is not None else None)
elif isinstance(entry, int):
return entry
else:
raise TypeError(Subtensor.e_indextype, entry)
self.idx_list = map(convert, idx_list)
def make_node(self, x, *inputs): def make_node(self, x, *inputs):
x = as_tensor(x) x = as_tensor(x)
...@@ -1269,10 +1271,42 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), S ...@@ -1269,10 +1271,42 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), S
class SetSubtensor(Subtensor): class SetSubtensor(Op):
"""WRITEME""" """Set just some elements of a larger Tensor.
view_map = {}
destroy_map = {0: [0]} This is like numpy's
z[i,j,k] = <something>
"""
def __init__(self, idx_list, inplace=False):
self.idx_list = map(Subtensor.convert, idx_list)
self.inplace = inplace
if inplace:
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other) \
and self.idx_list == other.idx_list \
and self.inplace == other.inplace
def __hash__(self):
idx_list = tuple((entry.start, entry.stop, entry.step)
if isinstance(entry, slice)
else entry
for entry in self.idx_list)
return hash(type(self)) ^ hash(idx_list) ^ hash(self.inplace)
def __str__(self):
indices = []
for entry in self.idx_list:
if isinstance(entry, slice):
indices.append(":".join("" if x is None else str(x) for x in [entry.start, entry.stop, entry.step]))
else:
indices.append(str(entry))
return "%s%s{%s}" % ('Inplace' if self.inplace else '',
self.__class__.__name__, ", ".join(indices))
def make_node(self, x, y, *inputs): def make_node(self, x, y, *inputs):
x, y = map(as_tensor, [x, y]) x, y = map(as_tensor, [x, y])
...@@ -1318,6 +1352,8 @@ class SetSubtensor(Subtensor): ...@@ -1318,6 +1352,8 @@ class SetSubtensor(Subtensor):
cdata = tuple(map(convert, self.idx_list)) cdata = tuple(map(convert, self.idx_list))
if len(cdata) == 1: if len(cdata) == 1:
cdata = cdata[0] cdata = cdata[0]
if not self.inplace:
x = x.copy()
x.__setitem__(cdata, y) x.__setitem__(cdata, y)
out[0] = x out[0] = x
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from .. import gof from .. import gof
from ..gof import opt, InconsistencyError from ..gof import opt, InconsistencyError, TopoOptimizer
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
from .. import scalar from .. import scalar
import basic as T import basic as T
...@@ -273,6 +273,17 @@ def local_subtensor_make_vector(node): ...@@ -273,6 +273,17 @@ def local_subtensor_make_vector(node):
register_canonicalize(local_subtensor_make_vector) register_canonicalize(local_subtensor_make_vector)
#after priority 50 Destructive inplace operations
#gemm is the first one now, at priority 70
@gof.local_optimizer([None])
def local_inplace_setsubtensor(node):
if isinstance(node.op, T.SetSubtensor) and not node.op.inplace:
new_op = T.SetSubtensor(node.op.idx_list, inplace=True)
new_node = new_op(*node.inputs)
return new_node.outputs
return False
compile.optdb.register('inplace_setsubtensor', TopoOptimizer(local_inplace_setsubtensor), 60, 'fast_run', 'inplace') #DEBUG
################## ##################
# Middleman cuts # # Middleman cuts #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论