提交 ffafbbc9 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add support for compute_test_value inside the inner function of scan.

Test different cases of error and normal behaviour.
上级 41d3e005
import os, sys, traceback, warnings
import numpy import numpy
import unittest import unittest
import theano import theano
import warnings
from theano import config from theano import config
from theano import tensor as T from theano import tensor as T
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
...@@ -164,7 +165,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -164,7 +165,7 @@ class TestComputeTestValue(unittest.TestCase):
finally: finally:
theano.config.compute_test_value = orig_compute_test_value theano.config.compute_test_value = orig_compute_test_value
def notest_scan(self): def test_scan(self):
""" """
Do not run this test as the compute_test_value mechanism is known not to work with Scan. Do not run this test as the compute_test_value mechanism is known not to work with Scan.
TODO: fix scan to work with compute_test_value TODO: fix scan to work with compute_test_value
...@@ -172,15 +173,16 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -172,15 +173,16 @@ class TestComputeTestValue(unittest.TestCase):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = 'raise' theano.config.compute_test_value = 'raise'
#theano.config.compute_test_value = 'warn'
k = T.iscalar("k") k = T.iscalar("k")
A = T.vector("A") A = T.vector("A")
k.tag.test_value = 3 k.tag.test_value = 3
A.tag.test_value = numpy.random.rand(5) A.tag.test_value = numpy.random.rand(5)
def fx(prior_result, A): def fx(prior_result, A):
return prior_results * A return prior_result * A
# Symbolic description of the result # Symbolic description of the result
result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A, result, updates = theano.scan(fn=fx,
outputs_info=T.ones_like(A), outputs_info=T.ones_like(A),
non_sequences=A, non_sequences=A,
n_steps=k) n_steps=k)
...@@ -192,3 +194,81 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -192,3 +194,81 @@ class TestComputeTestValue(unittest.TestCase):
assert hasattr(final_result.tag, 'test_value') assert hasattr(final_result.tag, 'test_value')
finally: finally:
theano.config.compute_test_value = orig_compute_test_value theano.config.compute_test_value = orig_compute_test_value
def test_scan_err1(self):
# This test should fail when building fx for the first time
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'raise'
k = T.iscalar("k")
A = T.matrix("A")
k.tag.test_value = 3
A.tag.test_value = numpy.random.rand(5,3)
def fx(prior_result, A):
return T.dot(prior_result, A)
# Since we have to inspect the traceback,
# we cannot simply use self.assertRaises()
try:
theano.scan(
fn=fx,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k)
assert False
except ValueError, e:
# Get traceback
tb = sys.exc_info()[2]
# Get frame info 3 layers up
frame_info = traceback.extract_tb(tb)[-3]
# We should be in the "fx" function defined above
assert os.path.split(frame_info[0])[1] == 'test_compute_test_value.py'
assert frame_info[2] == 'fx'
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_scan_err2(self):
# This test should not fail when building fx for the first time,
# but when calling the scan's perform()
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'raise'
k = T.iscalar("k")
A = T.matrix("A")
k.tag.test_value = 3
A.tag.test_value = numpy.random.rand(5,3)
def fx(prior_result, A):
return T.dot(prior_result, A)
self.assertRaises(ValueError,
theano.scan,
fn=fx,
outputs_info=T.ones_like(A.T),
non_sequences=A,
n_steps=k)
# Since we have to inspect the traceback,
# we cannot simply use self.assertRaises()
try:
theano.scan(
fn=fx,
outputs_info=T.ones_like(A.T),
non_sequences=A,
n_steps=k)
assert False
except ValueError, e:
# Get traceback
tb = sys.exc_info()[2]
# Get last frame info
frame_info = traceback.extract_tb(tb)[-1]
# We should be in scan_op.py, function 'perform'
assert os.path.split(frame_info[0])[1] == 'scan_op.py'
assert frame_info[2] == 'perform'
finally:
theano.config.compute_test_value = orig_compute_test_value
...@@ -39,6 +39,7 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>" ...@@ -39,6 +39,7 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import itertools import itertools
import logging import logging
import numpy import numpy
import warnings
from theano.compile import SharedVariable, function from theano.compile import SharedVariable, function
from theano import compile from theano import compile
...@@ -413,10 +414,21 @@ def scan( fn ...@@ -413,10 +414,21 @@ def scan( fn
# If not we need to use copies, that will be replaced at # If not we need to use copies, that will be replaced at
# each frame by the corresponding slice # each frame by the corresponding slice
_seq_val = tensor.as_tensor_variable(seq['input'])
nw_slice = _seq_val[0].type()
actual_slice = seq['input'][k-mintap] actual_slice = seq['input'][k-mintap]
_seq_val = tensor.as_tensor_variable(seq['input'])
_seq_val_slice = _seq_val[k-mintap]
nw_slice = _seq_val_slice.type()
# Try to transfer test_value to the new variable
if config.compute_test_value != 'off':
try:
nw_slice.tag.test_value = gof.Op._get_test_value(_seq_val_slice)
except AttributeError, e:
if config.compute_test_value != 'ignore':
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
info(('Cannot compute test value for the inner '
'function of scan, input value missing'), e)
# Add names to slices for debugging and pretty printing .. # Add names to slices for debugging and pretty printing ..
# that is if the input already has a name # that is if the input already has a name
...@@ -451,6 +463,7 @@ def scan( fn ...@@ -451,6 +463,7 @@ def scan( fn
inner_slices.append( actual_slice ) inner_slices.append( actual_slice )
n_seqs += 1 n_seqs += 1
# Since we've added all sequences now we need to level them up based on # Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes # n_steps or their different shapes
lengths_vec = [] lengths_vec = []
...@@ -533,8 +546,21 @@ def scan( fn ...@@ -533,8 +546,21 @@ def scan( fn
actual_arg = init_out['initial'] actual_arg = init_out['initial']
arg = safe_new(init_out['initial']) arg = safe_new(init_out['initial'])
# Try to transfer test_value to the new variable
if config.compute_test_value != 'off':
try:
arg.tag.test_value = gof.Op._get_test_value(actual_arg)
except AttributeError, e:
if config.compute_test_value != 'ignore':
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
info(('Cannot compute test value for the inner '
'function of scan, input value missing'), e)
if getattr(init_out['initial'],'name', None) is not None: if getattr(init_out['initial'],'name', None) is not None:
arg.name = init_out['initial'].name+'[t-1]' arg.name = init_out['initial'].name+'[t-1]'
# We need now to allocate space for storing the output and copy # We need now to allocate space for storing the output and copy
# the initial state over. We do this using the expand function # the initial state over. We do this using the expand function
# defined in scan utils # defined in scan utils
...@@ -576,7 +602,19 @@ def scan( fn ...@@ -576,7 +602,19 @@ def scan( fn
# create a new slice # create a new slice
actual_nw_slice = init_out['initial'][k+mintap] actual_nw_slice = init_out['initial'][k+mintap]
_init_out_var = tensor.as_tensor_variable(init_out['initial']) _init_out_var = tensor.as_tensor_variable(init_out['initial'])
nw_slice = _init_out_var[0].type() _init_out_var_slice = _init_out_var[k+mintap]
nw_slice = _init_out_var_slice.type()
# Try to transfer test_value to the new variable
if config.compute_test_value != 'off':
try:
nw_slice.tag.test_value = Op._get_test_value(_init_out_var_slice)
except AttributeError, e:
if config.compute_test_value != 'ignore':
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
info(('Cannot compute test value for the inner '
'function of scan, input value missing.'), e)
# give it a name or debugging and pretty printing # give it a name or debugging and pretty printing
if getattr(init_out['initial'],'name', None) is not None: if getattr(init_out['initial'],'name', None) is not None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论