提交 d18eacd8 authored 作者: James Bergstra's avatar James Bergstra

merged

......@@ -156,7 +156,7 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
self.debug_all_apps.remove(app)
#UPDATE self.clients
for i, input in enumerate(app.inputs):
for i, input in enumerate(set(app.inputs)):
del self.clients[input][app]
if getattr(app.op, 'destroy_map', {}):
......
......@@ -227,6 +227,7 @@ class Env(utils.object2):
For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(env, node, i, old_r, new_r)
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == 'output':
r = self.outputs[i]
if not r.type == new_r.type:
......
......@@ -112,6 +112,8 @@ class Linker(object):
class Container(object):
def __init__(self, r, storage, readonly = False, strict = False, name = None):
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
#self.r = r
if isinstance(r, Type):
self.type = r
......@@ -127,6 +129,9 @@ class Container(object):
if self.readonly:
raise Exception("Cannot set readonly storage: %s" % self.name)
try:
if value is None:
self.storage[0] = None
return
if self.strict:
self.storage[0] = self.type.filter(value, strict = True)
else:
......
from functools import partial
import graph
import sys
class AlreadyThere(Exception):
......@@ -97,7 +98,12 @@ class ReplaceValidate(History, Validator):
def replace_all_validate(self, env, replacements):
chk = env.checkpoint()
for r, new_r in replacements:
env.replace(r, new_r)
try:
env.replace(r, new_r)
except Exception, e:
print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e
env.revert(chk) # this might fail; env.replace should never raise an exception (it kinda needs better internal error handling)
raise
try:
env.validate()
except:
......
差异被折叠。
......@@ -1212,7 +1212,7 @@ class Subtensor(Op):
def __init__(self, idx_list):
def convert(entry, slice_ok=True):
scal_types =[scal.int64, scal.int32, scal.int16, scal.int8]
scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [bscalar, iscalar, lscalar]
if isinstance(entry, gof.Result) and entry.type in scal_types:
return entry.type
......@@ -1358,6 +1358,10 @@ class SetSubtensor(Subtensor):
x.__setitem__(cdata, y)
out[0] = x
def split(x, splits_size, n_splits, axis=0):
the_split = Split(n_splits)
return the_split(x, axis, splits_size)
class Split(Op):
"""Partition a `TensorResult` along some axis.
......@@ -1366,9 +1370,9 @@ class Split(Op):
x = vector()
splits = lvector()
# you have to declare right away how many split_points there will be.
ra, rb, rc = split(x, axis=0, points=splits, n_splits=3)
ra, rb, rc = split(x, splits, n_splits = 3, axis = 0)
f = compile([x, splits], [ra, rb, rc])
f = function([x, splits], [ra, rb, rc])
a, b, c = f([0,1,2,3,4,5,6], [3, 2, 1])
......@@ -2055,7 +2059,7 @@ def grad(cost, wrt, g_cost=None):
Tensor(dtype = p.type.dtype, broadcastable = []),
numpy.asarray(0, dtype=p.type.dtype))
if isinstance(wrt, list):
if isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in wrt]
else:
return gmap.get(wrt, zero(wrt))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论