提交 6ed25a61 authored 作者: Frederic's avatar Frederic

[BUG, CRASH] Correctly set the view_map for TypedList.

This was detected by a crash in DebugMode in the Theano buildbot as another change made Type.may_share_memory obligatory in DebugMode. As this discovered a bug, I'll keep this mandatory.
上级 06afae25
from type import TypedListType import copy
import numpy
from type import TypedListType
import theano import theano
from theano.gof import Apply, Constant, Op, Variable from theano.gof import Apply, Constant, Op, Variable
from theano.tensor.type_other import SliceType from theano.tensor.type_other import SliceType
from theano import tensor as T from theano import tensor as T
from theano.compile.debugmode import _lessbroken_deepcopy
import numpy
class _typed_list_py_operators: class _typed_list_py_operators:
...@@ -51,6 +53,8 @@ TypedListType.Variable = TypedListVariable ...@@ -51,6 +53,8 @@ TypedListType.Variable = TypedListVariable
class GetItem(Op): class GetItem(Op):
# See doc in instance of this Op or function after this class definition. # See doc in instance of this Op or function after this class definition.
view_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -112,6 +116,13 @@ class Append(Op): ...@@ -112,6 +116,13 @@ class Append(Op):
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
# TODO: make destroy_handler support having views and
# destroyed version of multiple inputs.
# self.view_map = {0: [1]}
else:
# TODO: make destroy_handler support multiple view
# self.view_map = {0: [0, 1]}
self.view_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
...@@ -129,12 +140,15 @@ class Append(Op): ...@@ -129,12 +140,15 @@ class Append(Op):
out[0] = list(x) out[0] = list(x)
else: else:
out[0] = x out[0] = x
# need to copy toAppend due to destroy_handler limitation
toAppend = _lessbroken_deepcopy(toAppend)
out[0].append(toAppend) out[0].append(toAppend)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def c_code(self, node, name, inp, out, sub): # DISABLED AS WE NEED TO UPDATE IT TO COPY toAppend().
def _c_code_(self, node, name, inp, out, sub):
x_name, toAppend = inp[0], inp[1] x_name, toAppend = inp[0], inp[1]
output_name = out[0] output_name = out[0]
fail = sub['fail'] fail = sub['fail']
...@@ -174,6 +188,13 @@ class Extend(Op): ...@@ -174,6 +188,13 @@ class Extend(Op):
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
# TODO: make destroy_handler support having views and
# destroyed version of multiple inputs.
# self.view_map = {0: [1]}
else:
# TODO: make destroy_handler support multiple view
# self.view_map = {0: [0, 1]}
self.view_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
...@@ -191,12 +212,17 @@ class Extend(Op): ...@@ -191,12 +212,17 @@ class Extend(Op):
out[0] = list(x) out[0] = list(x)
else: else:
out[0] = x out[0] = x
out[0].extend(toAppend) # need to copy toAppend due to destroy_handler limitation
if toAppend:
o = out[0]
for i in toAppend:
o.append(_lessbroken_deepcopy(i))
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def c_code(self, node, name, inp, out, sub): # DISABLED AS WE NEED TO UPDATE IT TO COPY toAppend().
def _c_code_(self, node, name, inp, out, sub):
x_name, toAppend = inp[0], inp[1] x_name, toAppend = inp[0], inp[1]
output_name = out[0] output_name = out[0]
fail = sub['fail'] fail = sub['fail']
...@@ -222,7 +248,7 @@ class Extend(Op): ...@@ -222,7 +248,7 @@ class Extend(Op):
Py_INCREF(%(output_name)s); Py_INCREF(%(output_name)s);
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version_(self):
return (1,) return (1,)
extend = Extend() extend = Extend()
...@@ -240,6 +266,13 @@ class Insert(Op): ...@@ -240,6 +266,13 @@ class Insert(Op):
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
# TODO: make destroy_handler support having views and
# destroyed version of multiple inputs.
# self.view_map = {0: [2]}
else:
# TODO: make destroy_handler support multiple view
# self.view_map = {0: [0, 2]}
self.view_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
...@@ -262,12 +295,15 @@ class Insert(Op): ...@@ -262,12 +295,15 @@ class Insert(Op):
out[0] = list(x) out[0] = list(x)
else: else:
out[0] = x out[0] = x
# need to copy toAppend due to destroy_handler limitation
toInsert = _lessbroken_deepcopy(toInsert)
out[0].insert(index, toInsert) out[0].insert(index, toInsert)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def c_code(self, node, name, inp, out, sub): # DISABLED AS WE NEED TO UPDATE IT TO COPY toAppend().
def _c_code_(self, node, name, inp, out, sub):
x_name, index, toInsert = inp[0], inp[1], inp[2] x_name, index, toInsert = inp[0], inp[1], inp[2]
output_name = out[0] output_name = out[0]
fail = sub['fail'] fail = sub['fail']
...@@ -308,6 +344,8 @@ class Remove(Op): ...@@ -308,6 +344,8 @@ class Remove(Op):
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
else:
self.view_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
...@@ -360,6 +398,8 @@ class Reverse(Op): ...@@ -360,6 +398,8 @@ class Reverse(Op):
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
else:
self.view_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace return type(self) == type(other) and self.inplace == other.inplace
......
...@@ -76,6 +76,14 @@ class TypedListType(gof.Type): ...@@ -76,6 +76,14 @@ class TypedListType(gof.Type):
return True return True
def may_share_memory(self, a, b):
if a is b:
return True
for idx1 in range(len(a)):
for idx2 in range(len(b)):
if self.ttype.may_share_memory(a[idx1], b[idx2]):
return True
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
return """ return """
PyListObject* %(name)s; PyListObject* %(name)s;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论