提交 6ed25a61 authored 作者: Frederic's avatar Frederic

[BUG, CRASH] Correctly set the view_map for TypedList.

This was detected by a crash in DebugMode in the Theano buildbot as another change made Type.may_share_memory obligatory in DebugMode. As this discovered a bug, I'll keep this mandatory.
上级 06afae25
from type import TypedListType
import copy
import numpy
from type import TypedListType
import theano
from theano.gof import Apply, Constant, Op, Variable
from theano.tensor.type_other import SliceType
from theano import tensor as T
import numpy
from theano.compile.debugmode import _lessbroken_deepcopy
class _typed_list_py_operators:
......@@ -51,6 +53,8 @@ TypedListType.Variable = TypedListVariable
class GetItem(Op):
# See doc in instance of this Op or function after this class definition.
view_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other)
......@@ -112,6 +116,13 @@ class Append(Op):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
# TODO: make destroy_handler support having views and
# destroyed version of multiple inputs.
# self.view_map = {0: [1]}
else:
# TODO: make destroy_handler support multiple view
# self.view_map = {0: [0, 1]}
self.view_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace
......@@ -129,12 +140,15 @@ class Append(Op):
out[0] = list(x)
else:
out[0] = x
# need to copy toAppend due to destroy_handler limitation
toAppend = _lessbroken_deepcopy(toAppend)
out[0].append(toAppend)
def __str__(self):
return self.__class__.__name__
def c_code(self, node, name, inp, out, sub):
# DISABLED AS WE NEED TO UPDATE IT TO COPY toAppend().
def _c_code_(self, node, name, inp, out, sub):
x_name, toAppend = inp[0], inp[1]
output_name = out[0]
fail = sub['fail']
......@@ -174,6 +188,13 @@ class Extend(Op):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
# TODO: make destroy_handler support having views and
# destroyed version of multiple inputs.
# self.view_map = {0: [1]}
else:
# TODO: make destroy_handler support multiple view
# self.view_map = {0: [0, 1]}
self.view_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace
......@@ -191,12 +212,17 @@ class Extend(Op):
out[0] = list(x)
else:
out[0] = x
out[0].extend(toAppend)
# need to copy toAppend due to destroy_handler limitation
if toAppend:
o = out[0]
for i in toAppend:
o.append(_lessbroken_deepcopy(i))
def __str__(self):
return self.__class__.__name__
def c_code(self, node, name, inp, out, sub):
# DISABLED AS WE NEED TO UPDATE IT TO COPY toAppend().
def _c_code_(self, node, name, inp, out, sub):
x_name, toAppend = inp[0], inp[1]
output_name = out[0]
fail = sub['fail']
......@@ -222,7 +248,7 @@ class Extend(Op):
Py_INCREF(%(output_name)s);
""" % locals()
def c_code_cache_version(self):
def c_code_cache_version_(self):
return (1,)
extend = Extend()
......@@ -240,6 +266,13 @@ class Insert(Op):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
# TODO: make destroy_handler support having views and
# destroyed version of multiple inputs.
# self.view_map = {0: [2]}
else:
# TODO: make destroy_handler support multiple view
# self.view_map = {0: [0, 2]}
self.view_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace
......@@ -262,12 +295,15 @@ class Insert(Op):
out[0] = list(x)
else:
out[0] = x
# need to copy toAppend due to destroy_handler limitation
toInsert = _lessbroken_deepcopy(toInsert)
out[0].insert(index, toInsert)
def __str__(self):
return self.__class__.__name__
def c_code(self, node, name, inp, out, sub):
# DISABLED AS WE NEED TO UPDATE IT TO COPY toAppend().
def _c_code_(self, node, name, inp, out, sub):
x_name, index, toInsert = inp[0], inp[1], inp[2]
output_name = out[0]
fail = sub['fail']
......@@ -308,6 +344,8 @@ class Remove(Op):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
else:
self.view_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace
......@@ -360,6 +398,8 @@ class Reverse(Op):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [0]}
else:
self.view_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace
......
......@@ -76,6 +76,14 @@ class TypedListType(gof.Type):
return True
def may_share_memory(self, a, b):
if a is b:
return True
for idx1 in range(len(a)):
for idx2 in range(len(b)):
if self.ttype.may_share_memory(a[idx1], b[idx2]):
return True
def c_declare(self, name, sub, check_input=True):
return """
PyListObject* %(name)s;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论