提交 bc2f50e5 authored 作者: nouiz's avatar nouiz

Merge pull request #340 from pascanur/fix_tests_sandbox_scan

Fix tests sandbox scan
...@@ -49,7 +49,8 @@ import numpy ...@@ -49,7 +49,8 @@ import numpy
from theano.compile import SharedVariable, function from theano.compile import SharedVariable, function
from theano import compile from theano import compile
from theano import gof from theano import gof
from theano.tensor import opt from theano.tensor import opt, TensorVariable
from theano.tensor.sharedvar import TensorSharedVariable
from theano import tensor from theano import tensor
from theano import config from theano import config
from theano.updates import Updates from theano.updates import Updates
...@@ -435,10 +436,10 @@ def scan(fn, ...@@ -435,10 +436,10 @@ def scan(fn,
pos = len(lengths) pos = len(lengths)
for sv in shared_inputs: for sv in shared_inputs:
if sv in update_d: if sv in update_d:
if isinstance(sv, TensorType): if isinstance(sv, (TensorVariable, TensorSharedVariable)):
# We can treat it as a sit sot # We can treat it as a sit sot
nw_state = scan_utils.expand( nw_state = scan_utils.expand(
tensor.unbroadcast(tensor.shape_padleft(sv, 0), T)) tensor.unbroadcast(tensor.shape_padleft(sv), 0), T)
additional_lengths.append(scalar_shared(numpy.int64(0), additional_lengths.append(scalar_shared(numpy.int64(0),
name='l%d' % pos)) name='l%d' % pos))
pos = pos + 1 pos = pos + 1
...@@ -454,6 +455,17 @@ def scan(fn, ...@@ -454,6 +455,17 @@ def scan(fn,
non_numeric_output_states.append(update_d[sv]) non_numeric_output_states.append(update_d[sv])
original_non_numeric_shared_variables.append(sv) original_non_numeric_shared_variables.append(sv)
# Replace shared variables in the update
_additional_output_states = []
replace = {}
for sv, buf in zip(original_numeric_shared_variables,
additional_input_states):
replace[sv] = buf[t]
for out in additional_output_states:
_additional_output_states.append(
scan_utils.clone(out, replace=replace))
additional_output_states = _additional_output_states
# 5.2 Collect inputs/outputs of the inner function # 5.2 Collect inputs/outputs of the inner function
inputs = [] inputs = []
outputs = [] outputs = []
...@@ -515,7 +527,7 @@ def scan(fn, ...@@ -515,7 +527,7 @@ def scan(fn,
for pos in xrange(len(states_and_outputs)): for pos in xrange(len(states_and_outputs)):
out = scan_utils.ScanPermutation(mintaps[pos])( out = scan_utils.ScanPermutation(mintaps[pos])(
scan_outputs_update_rules[pos], t) scan_outputs_update_rules[pos], t)
scan_outputs.append(out[mintap:]) scan_outputs.append(out[mintaps[pos]:])
# 5.6 Construct updates dictionary # 5.6 Construct updates dictionary
update_rules = scan_outputs_update_rules[len(states_and_outputs):] update_rules = scan_outputs_update_rules[len(states_and_outputs):]
updates = {} updates = {}
...@@ -553,7 +565,8 @@ def one_step_scan(fn, ...@@ -553,7 +565,8 @@ def one_step_scan(fn,
arg_info) arg_info)
# go through the taps # go through the taps
mintap = abs(numpy.min(arg_info['taps'])) mintap = abs(numpy.min(arg_info['taps']))
states_slices.append(arg_info['initial'][k + mintap]) states_slices.extend(
[arg_info['initial'][k + mintap] for k in arg_info['taps']])
# Re-order args # Re-order args
args = (inputs_slices + states_slices + parameters) args = (inputs_slices + states_slices + parameters)
......
...@@ -277,6 +277,7 @@ class ScanOp(PureOp): ...@@ -277,6 +277,7 @@ class ScanOp(PureOp):
for var, length, val in state_buffers: for var, length, val in state_buffers:
var.set_value(val[0], borrow=True) var.set_value(val[0], borrow=True)
length.set_value(val[0].shape[0], borrow=True) length.set_value(val[0].shape[0], borrow=True)
self.index.set_value(numpy.int64(0))
# grab fixed arguments # grab fixed arguments
fix_args = [x[0] for x in non_tensor_buffers] fix_args = [x[0] for x in non_tensor_buffers]
while cont and pos < node_input_storage[0][0]: while cont and pos < node_input_storage[0][0]:
...@@ -320,6 +321,7 @@ class ScanOp(PureOp): ...@@ -320,6 +321,7 @@ class ScanOp(PureOp):
for var, length, val in state_buffers: for var, length, val in state_buffers:
var.set_value(val[0], borrow=True) var.set_value(val[0], borrow=True)
length.set_value(val[0].shape[0], borrow=True) length.set_value(val[0].shape[0], borrow=True)
self.index.set_value(numpy.int64(0))
# grab fixed arguments # grab fixed arguments
fix_args = [x[0] for x in non_tensor_buffers] fix_args = [x[0] for x in non_tensor_buffers]
for dx in xrange(node_input_storage[0][0]): for dx in xrange(node_input_storage[0][0]):
......
...@@ -220,17 +220,29 @@ def canonical_arguments(sequences, ...@@ -220,17 +220,29 @@ def canonical_arguments(sequences,
# We cut the sequence such that seq[i] to correspond to # We cut the sequence such that seq[i] to correspond to
# seq[i-k] # seq[i-k]
if maxtap < 0: if maxtap < 0:
offset = abs(maxtap) offset_max = abs(maxtap)
else: else:
offset = 0 offset_max = 0
if mintap < 0:
offset_min = abs(mintap)
else:
offset_min = 0
nw_input = orig_input nw_input = orig_input
if maxtap == mintap and maxtap != 0: if maxtap == mintap and maxtap != 0:
nw_input = nw_input[:abs(maxtap)] if maxtap > 0:
elif maxtap - k != 0: nw_input = nw_input[maxtap:]
nw_input = nw_input[offset + k - mintap:\ else:
-(maxtap - k)] nw_input = nw_input[:maxtap]
else:
st = k + offset_min
if maxtap > 0:
ed = - (maxtap + offset_min - st)
else:
ed = - (offset_min -st)
if ed != 0:
nw_input = nw_input[st:ed]
else: else:
nw_input = nw_input[offset + k - mintap:] nw_input = nw_input[st:]
inputs.append(nw_input) inputs.append(nw_input)
else: else:
raise ValueError('Provided sequence makes no sense', str(input)) raise ValueError('Provided sequence makes no sense', str(input))
......
...@@ -166,7 +166,7 @@ def grab_scan_node(output): ...@@ -166,7 +166,7 @@ def grab_scan_node(output):
class TestScanUtils(unittest.TestCase): class TestScanUtils(unittest.TestCase):
def test_cloning_no_replace_strict_copy_inputs(self): def test001_cloning_no_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone # This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and # function that scan uses internally and that pfunc uses now and
# that users might want to use # that users might want to use
...@@ -185,7 +185,7 @@ class TestScanUtils(unittest.TestCase): ...@@ -185,7 +185,7 @@ class TestScanUtils(unittest.TestCase):
assert x in f2_inp assert x in f2_inp
assert y in f2_inp assert y in f2_inp
def test_cloning_no_replace_strict_not_copy_inputs(self): def test002_cloning_no_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone # This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and # function that scan uses internally and that pfunc uses now and
# that users might want to use # that users might want to use
...@@ -204,7 +204,7 @@ class TestScanUtils(unittest.TestCase): ...@@ -204,7 +204,7 @@ class TestScanUtils(unittest.TestCase):
assert not x in f2_inp assert not x in f2_inp
assert not y in f2_inp assert not y in f2_inp
def test_cloning_replace_strict_copy_inputs(self): def test003_cloning_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone # This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and # function that scan uses internally and that pfunc uses now and
# that users might want to use # that users might want to use
...@@ -223,7 +223,7 @@ class TestScanUtils(unittest.TestCase): ...@@ -223,7 +223,7 @@ class TestScanUtils(unittest.TestCase):
assert x in f2_inp assert x in f2_inp
assert y2 in f2_inp assert y2 in f2_inp
def test_cloning_replace_not_strict_copy_inputs(self): def test004_cloning_replace_not_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone # This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and # function that scan uses internally and that pfunc uses now and
# that users might want to use # that users might want to use
...@@ -242,7 +242,7 @@ class TestScanUtils(unittest.TestCase): ...@@ -242,7 +242,7 @@ class TestScanUtils(unittest.TestCase):
assert x in f2_inp assert x in f2_inp
assert y2 in f2_inp assert y2 in f2_inp
def test_cloning_replace_strict_not_copy_inputs(self): def test005_cloning_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone # This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and # function that scan uses internally and that pfunc uses now and
# that users might want to use # that users might want to use
...@@ -261,7 +261,7 @@ class TestScanUtils(unittest.TestCase): ...@@ -261,7 +261,7 @@ class TestScanUtils(unittest.TestCase):
assert not x in f2_inp assert not x in f2_inp
assert not y2 in f2_inp assert not y2 in f2_inp
def test_cloning_replace_not_strict_not_copy_inputs(self): def test006_cloning_replace_not_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone # This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and # function that scan uses internally and that pfunc uses now and
# that users might want to use # that users might want to use
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论