提交 5deab31d authored 作者: Olivier Breuleux's avatar Olivier Breuleux

pre-cleanup klass

上级 ac62de29
...@@ -112,6 +112,8 @@ class Linker(object): ...@@ -112,6 +112,8 @@ class Linker(object):
class Container(object): class Container(object):
def __init__(self, r, storage, readonly = False, strict = False, name = None): def __init__(self, r, storage, readonly = False, strict = False, name = None):
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
#self.r = r #self.r = r
if isinstance(r, Type): if isinstance(r, Type):
self.type = r self.type = r
...@@ -127,6 +129,9 @@ class Container(object): ...@@ -127,6 +129,9 @@ class Container(object):
if self.readonly: if self.readonly:
raise Exception("Cannot set readonly storage: %s" % self.name) raise Exception("Cannot set readonly storage: %s" % self.name)
try: try:
if value is None:
self.storage[0] = None
return
if self.strict: if self.strict:
self.storage[0] = self.type.filter(value, strict = True) self.storage[0] = self.type.filter(value, strict = True)
else: else:
......
差异被折叠。
...@@ -1212,7 +1212,7 @@ class Subtensor(Op): ...@@ -1212,7 +1212,7 @@ class Subtensor(Op):
def __init__(self, idx_list): def __init__(self, idx_list):
def convert(entry, slice_ok=True): def convert(entry, slice_ok=True):
scal_types =[scal.int64, scal.int32, scal.int16, scal.int8] scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [bscalar, iscalar, lscalar] tensor_types = [bscalar, iscalar, lscalar]
if isinstance(entry, gof.Result) and entry.type in scal_types: if isinstance(entry, gof.Result) and entry.type in scal_types:
return entry.type return entry.type
...@@ -2059,7 +2059,7 @@ def grad(cost, wrt, g_cost=None): ...@@ -2059,7 +2059,7 @@ def grad(cost, wrt, g_cost=None):
Tensor(dtype = p.type.dtype, broadcastable = []), Tensor(dtype = p.type.dtype, broadcastable = []),
numpy.asarray(0, dtype=p.type.dtype)) numpy.asarray(0, dtype=p.type.dtype))
if isinstance(wrt, list): if isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in wrt] return [gmap.get(p, zero(p)) for p in wrt]
else: else:
return gmap.get(wrt, zero(wrt)) return gmap.get(wrt, zero(wrt))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论