提交 bc2f50e5 authored 作者: nouiz's avatar nouiz

Merge pull request #340 from pascanur/fix_tests_sandbox_scan

Fix tests sandbox scan
......@@ -49,7 +49,8 @@ import numpy
from theano.compile import SharedVariable, function
from theano import compile
from theano import gof
from theano.tensor import opt
from theano.tensor import opt, TensorVariable
from theano.tensor.sharedvar import TensorSharedVariable
from theano import tensor
from theano import config
from theano.updates import Updates
......@@ -435,10 +436,10 @@ def scan(fn,
pos = len(lengths)
for sv in shared_inputs:
if sv in update_d:
if isinstance(sv, TensorType):
if isinstance(sv, (TensorVariable, TensorSharedVariable)):
# We can treat it as a sit sot
nw_state = scan_utils.expand(
tensor.unbroadcast(tensor.shape_padleft(sv, 0), T))
tensor.unbroadcast(tensor.shape_padleft(sv), 0), T)
additional_lengths.append(scalar_shared(numpy.int64(0),
name='l%d' % pos))
pos = pos + 1
......@@ -454,6 +455,17 @@ def scan(fn,
non_numeric_output_states.append(update_d[sv])
original_non_numeric_shared_variables.append(sv)
# Replace shared variables in the update
_additional_output_states = []
replace = {}
for sv, buf in zip(original_numeric_shared_variables,
additional_input_states):
replace[sv] = buf[t]
for out in additional_output_states:
_additional_output_states.append(
scan_utils.clone(out, replace=replace))
additional_output_states = _additional_output_states
# 5.2 Collect inputs/outputs of the inner function
inputs = []
outputs = []
......@@ -515,7 +527,7 @@ def scan(fn,
for pos in xrange(len(states_and_outputs)):
out = scan_utils.ScanPermutation(mintaps[pos])(
scan_outputs_update_rules[pos], t)
scan_outputs.append(out[mintap:])
scan_outputs.append(out[mintaps[pos]:])
# 5.6 Construct updates dictionary
update_rules = scan_outputs_update_rules[len(states_and_outputs):]
updates = {}
......@@ -553,7 +565,8 @@ def one_step_scan(fn,
arg_info)
# go through the taps
mintap = abs(numpy.min(arg_info['taps']))
states_slices.append(arg_info['initial'][k + mintap])
states_slices.extend(
[arg_info['initial'][k + mintap] for k in arg_info['taps']])
# Re-order args
args = (inputs_slices + states_slices + parameters)
......
......@@ -277,6 +277,7 @@ class ScanOp(PureOp):
for var, length, val in state_buffers:
var.set_value(val[0], borrow=True)
length.set_value(val[0].shape[0], borrow=True)
self.index.set_value(numpy.int64(0))
# grab fixed arguments
fix_args = [x[0] for x in non_tensor_buffers]
while cont and pos < node_input_storage[0][0]:
......@@ -320,6 +321,7 @@ class ScanOp(PureOp):
for var, length, val in state_buffers:
var.set_value(val[0], borrow=True)
length.set_value(val[0].shape[0], borrow=True)
self.index.set_value(numpy.int64(0))
# grab fixed arguments
fix_args = [x[0] for x in non_tensor_buffers]
for dx in xrange(node_input_storage[0][0]):
......
......@@ -220,17 +220,29 @@ def canonical_arguments(sequences,
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]
if maxtap < 0:
offset = abs(maxtap)
offset_max = abs(maxtap)
else:
offset = 0
offset_max = 0
if mintap < 0:
offset_min = abs(mintap)
else:
offset_min = 0
nw_input = orig_input
if maxtap == mintap and maxtap != 0:
nw_input = nw_input[:abs(maxtap)]
elif maxtap - k != 0:
nw_input = nw_input[offset + k - mintap:\
-(maxtap - k)]
if maxtap > 0:
nw_input = nw_input[maxtap:]
else:
nw_input = nw_input[:maxtap]
else:
st = k + offset_min
if maxtap > 0:
ed = - (maxtap + offset_min - st)
else:
ed = - (offset_min -st)
if ed != 0:
nw_input = nw_input[st:ed]
else:
nw_input = nw_input[offset + k - mintap:]
nw_input = nw_input[st:]
inputs.append(nw_input)
else:
raise ValueError('Provided sequence makes no sense', str(input))
......
......@@ -166,7 +166,7 @@ def grab_scan_node(output):
class TestScanUtils(unittest.TestCase):
def test_cloning_no_replace_strict_copy_inputs(self):
def test001_cloning_no_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
......@@ -185,7 +185,7 @@ class TestScanUtils(unittest.TestCase):
assert x in f2_inp
assert y in f2_inp
def test_cloning_no_replace_strict_not_copy_inputs(self):
def test002_cloning_no_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
......@@ -204,7 +204,7 @@ class TestScanUtils(unittest.TestCase):
assert not x in f2_inp
assert not y in f2_inp
def test_cloning_replace_strict_copy_inputs(self):
def test003_cloning_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
......@@ -223,7 +223,7 @@ class TestScanUtils(unittest.TestCase):
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_not_strict_copy_inputs(self):
def test004_cloning_replace_not_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
......@@ -242,7 +242,7 @@ class TestScanUtils(unittest.TestCase):
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_strict_not_copy_inputs(self):
def test005_cloning_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
......@@ -261,7 +261,7 @@ class TestScanUtils(unittest.TestCase):
assert not x in f2_inp
assert not y2 in f2_inp
def test_cloning_replace_not_strict_not_copy_inputs(self):
def test006_cloning_replace_not_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论