提交 d18eacd8 authored 作者: James Bergstra's avatar James Bergstra

merged

...@@ -156,7 +156,7 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper): ...@@ -156,7 +156,7 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
self.debug_all_apps.remove(app) self.debug_all_apps.remove(app)
#UPDATE self.clients #UPDATE self.clients
for i, input in enumerate(app.inputs): for i, input in enumerate(set(app.inputs)):
del self.clients[input][app] del self.clients[input][app]
if getattr(app.op, 'destroy_map', {}): if getattr(app.op, 'destroy_map', {}):
......
...@@ -227,6 +227,7 @@ class Env(utils.object2): ...@@ -227,6 +227,7 @@ class Env(utils.object2):
For each feature that has a 'on_change_input' method, calls: For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(env, node, i, old_r, new_r) feature.on_change_input(env, node, i, old_r, new_r)
""" """
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == 'output': if node == 'output':
r = self.outputs[i] r = self.outputs[i]
if not r.type == new_r.type: if not r.type == new_r.type:
......
...@@ -112,6 +112,8 @@ class Linker(object): ...@@ -112,6 +112,8 @@ class Linker(object):
class Container(object): class Container(object):
def __init__(self, r, storage, readonly = False, strict = False, name = None): def __init__(self, r, storage, readonly = False, strict = False, name = None):
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
#self.r = r #self.r = r
if isinstance(r, Type): if isinstance(r, Type):
self.type = r self.type = r
...@@ -127,6 +129,9 @@ class Container(object): ...@@ -127,6 +129,9 @@ class Container(object):
if self.readonly: if self.readonly:
raise Exception("Cannot set readonly storage: %s" % self.name) raise Exception("Cannot set readonly storage: %s" % self.name)
try: try:
if value is None:
self.storage[0] = None
return
if self.strict: if self.strict:
self.storage[0] = self.type.filter(value, strict = True) self.storage[0] = self.type.filter(value, strict = True)
else: else:
......
from functools import partial from functools import partial
import graph import graph
import sys
class AlreadyThere(Exception): class AlreadyThere(Exception):
...@@ -97,7 +98,12 @@ class ReplaceValidate(History, Validator): ...@@ -97,7 +98,12 @@ class ReplaceValidate(History, Validator):
def replace_all_validate(self, env, replacements): def replace_all_validate(self, env, replacements):
chk = env.checkpoint() chk = env.checkpoint()
for r, new_r in replacements: for r, new_r in replacements:
try:
env.replace(r, new_r) env.replace(r, new_r)
except Exception, e:
print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e
env.revert(chk) # this might fail; env.replace should never raise an exception (it kinda needs better internal error handling)
raise
try: try:
env.validate() env.validate()
except: except:
......
差异被折叠。
...@@ -1212,7 +1212,7 @@ class Subtensor(Op): ...@@ -1212,7 +1212,7 @@ class Subtensor(Op):
def __init__(self, idx_list): def __init__(self, idx_list):
def convert(entry, slice_ok=True): def convert(entry, slice_ok=True):
scal_types =[scal.int64, scal.int32, scal.int16, scal.int8] scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [bscalar, iscalar, lscalar] tensor_types = [bscalar, iscalar, lscalar]
if isinstance(entry, gof.Result) and entry.type in scal_types: if isinstance(entry, gof.Result) and entry.type in scal_types:
return entry.type return entry.type
...@@ -1358,6 +1358,10 @@ class SetSubtensor(Subtensor): ...@@ -1358,6 +1358,10 @@ class SetSubtensor(Subtensor):
x.__setitem__(cdata, y) x.__setitem__(cdata, y)
out[0] = x out[0] = x
def split(x, splits_size, n_splits, axis=0):
the_split = Split(n_splits)
return the_split(x, axis, splits_size)
class Split(Op): class Split(Op):
"""Partition a `TensorResult` along some axis. """Partition a `TensorResult` along some axis.
...@@ -1366,9 +1370,9 @@ class Split(Op): ...@@ -1366,9 +1370,9 @@ class Split(Op):
x = vector() x = vector()
splits = lvector() splits = lvector()
# you have to declare right away how many split_points there will be. # you have to declare right away how many split_points there will be.
ra, rb, rc = split(x, axis=0, points=splits, n_splits=3) ra, rb, rc = split(x, splits, n_splits = 3, axis = 0)
f = compile([x, splits], [ra, rb, rc]) f = function([x, splits], [ra, rb, rc])
a, b, c = f([0,1,2,3,4,5,6], [3, 2, 1]) a, b, c = f([0,1,2,3,4,5,6], [3, 2, 1])
...@@ -2055,7 +2059,7 @@ def grad(cost, wrt, g_cost=None): ...@@ -2055,7 +2059,7 @@ def grad(cost, wrt, g_cost=None):
Tensor(dtype = p.type.dtype, broadcastable = []), Tensor(dtype = p.type.dtype, broadcastable = []),
numpy.asarray(0, dtype=p.type.dtype)) numpy.asarray(0, dtype=p.type.dtype))
if isinstance(wrt, list): if isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in wrt] return [gmap.get(p, zero(p)) for p in wrt]
else: else:
return gmap.get(wrt, zero(wrt)) return gmap.get(wrt, zero(wrt))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论