提交 7a9862b2 authored 作者: Jakub Sygnowski's avatar Jakub Sygnowski

assert empty output for empty output_subset

上级 73620123
...@@ -231,11 +231,12 @@ def test_partial_function_with_updates(): ...@@ -231,11 +231,12 @@ def test_partial_function_with_updates():
optimizer=None, linker=linker_name)) optimizer=None, linker=linker_name))
g = theano.function([x], [x - 6], updates=[(y, y + 3)], mode=Mode( g = theano.function([x], [x - 6], updates=[(y, y + 3)], mode=Mode(
optimizer=None, linker=linker_name)) optimizer=None, linker=linker_name))
f(3, output_subset=[])
assert(y.get_value() == 4) assert f(3, output_subset=[]) == []
assert(g(30, output_subset=[0]) == [24]) assert y.get_value() == 4
g(40, output_subset=[]) assert g(30, output_subset=[0]) == [24]
assert(y.get_value() == 10) assert g(40, output_subset=[]) == []
assert y.get_value() == 10
check_updates(vm.VM_Linker(allow_partial_eval=True)) check_updates(vm.VM_Linker(allow_partial_eval=True))
check_updates('cvm') check_updates('cvm')
......
...@@ -332,8 +332,8 @@ class Stack(VM): ...@@ -332,8 +332,8 @@ class Stack(VM):
def __init__(self, nodes, thunks, pre_call_clear, def __init__(self, nodes, thunks, pre_call_clear,
storage_map, compute_map, fgraph, allow_gc, storage_map, compute_map, fgraph, allow_gc,
dependencies=None, callback=None, callback_input=None, n_updates, dependencies=None, callback=None,
n_updates=0): callback_input=None):
super(Stack, self).__init__(nodes, thunks, pre_call_clear) super(Stack, self).__init__(nodes, thunks, pre_call_clear)
self.allow_gc = allow_gc self.allow_gc = allow_gc
...@@ -420,8 +420,8 @@ class Stack(VM): ...@@ -420,8 +420,8 @@ class Stack(VM):
# apply_stack contains nodes # apply_stack contains nodes
if output_subset is not None: if output_subset is not None:
first_updated = len(self.outputs) - self.n_updates first_updated = len(self.outputs) - self.n_updates
output_subset = output_subset + range(first_updated, output_subset = output_subset + list(range(first_updated,
len(self.outputs)) len(self.outputs)))
apply_stack =\ apply_stack =\
[self.outputs[i].owner for i in output_subset [self.outputs[i].owner for i in output_subset
if self.outputs[i].owner] if self.outputs[i].owner]
...@@ -865,10 +865,10 @@ class VM_Linker(link.LocalLinker): ...@@ -865,10 +865,10 @@ class VM_Linker(link.LocalLinker):
nodes, thunks, pre_call_clear, nodes, thunks, pre_call_clear,
storage_map, compute_map, storage_map, compute_map,
self.fgraph, self.allow_gc, self.fgraph, self.allow_gc,
len(updated_vars),
dependencies=deps, dependencies=deps,
callback=self.callback, callback=self.callback,
callback_input=self.callback_input, callback_input=self.callback_input)
n_updates=len(updated_vars))
elif self.use_cloop: elif self.use_cloop:
# create a map from nodes to ints and vars to ints # create a map from nodes to ints and vars to ints
nodes_idx = {} nodes_idx = {}
...@@ -1006,8 +1006,8 @@ class VM_Linker(link.LocalLinker): ...@@ -1006,8 +1006,8 @@ class VM_Linker(link.LocalLinker):
nodes, thunks, pre_call_clear, nodes, thunks, pre_call_clear,
storage_map, compute_map, storage_map, compute_map,
self.fgraph, self.allow_gc, self.fgraph, self.allow_gc,
len(updated_vars),
dependencies=deps, dependencies=deps,
n_updates=len(updated_vars)
) )
return vm return vm
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论