提交 382b607c authored 作者: nouiz's avatar nouiz

Merge pull request #229 from pascanur/depricated_scan

Depricated scan
...@@ -391,12 +391,11 @@ def scan(fn, ...@@ -391,12 +391,11 @@ def scan(fn,
if isinstance(outs_info[i], dict): if isinstance(outs_info[i], dict):
# DEPRECATED : # DEPRECATED :
if outs_info[i].get('return_steps', None): if outs_info[i].get('return_steps', None):
_logger.warning( raise ValueError(
"Using `return_steps` has been deprecated. " "Using `return_steps` has been deprecated. "
"Simply select the entries you need using a " "Simply select the entries you need using a "
"subtensor. Scan will optimize memory " "subtensor. Scan will optimize memory "
"consumption, so do not worry about that.") "consumption, so do not worry about that.")
return_steps[i] = outs_info[i]['return_steps']
# END # END
if not isinstance(outs_info[i], dict): if not isinstance(outs_info[i], dict):
......
...@@ -208,8 +208,7 @@ def get_updates_and_outputs(ls): ...@@ -208,8 +208,7 @@ def get_updates_and_outputs(ls):
raise ValueError(error_msg) raise ValueError(error_msg)
elif is_updates(ls[0]): elif is_updates(ls[0]):
if is_outputs(ls[1]): if is_outputs(ls[1]):
_logger.warning(deprication_msg) raise ValueError(deprication_msg)
return (None, _list(ls[1]), dict(ls[0]))
elif is_condition(ls[1]): elif is_condition(ls[1]):
return (ls[1].condition, [], dict(ls[0])) return (ls[1].condition, [], dict(ls[0]))
else: else:
...@@ -225,15 +224,6 @@ def get_updates_and_outputs(ls): ...@@ -225,15 +224,6 @@ def get_updates_and_outputs(ls):
raise ValueError(error_msg) raise ValueError(error_msg)
else: else:
raise ValueError(error_msg) raise ValueError(error_msg)
elif is_updates(ls[0]):
if is_outputs(ls[1]):
if is_condition(ls[2]):
_logger.warning(deprication_msg)
return (ls[2].condition, _list(ls[1]), dict(ls[0]))
else:
raise ValueError(error_msg)
else:
raise ValueError(error_msg)
else: else:
raise ValueError(error_msg) raise ValueError(error_msg)
......
...@@ -2560,18 +2560,21 @@ class T_Scan(unittest.TestCase): ...@@ -2560,18 +2560,21 @@ class T_Scan(unittest.TestCase):
theano.dot(x_tm1, W), theano.dot(x_tm1, W),
y_tm1 + theano.dot(x_tm1, W_out)] y_tm1 + theano.dot(x_tm1, W_out)]
outputs, updates = theano.scan( f_rnn_cmpl rval, updates = theano.scan( f_rnn_cmpl
, [ u1 , [ u1
, u2] , u2]
, [ dict(store_steps = 3) , [ None
, dict(initial = x0, return_steps = 2) , dict(initial = x0)
, dict(initial=y0, taps=[-1,-3], , dict(initial=y0, taps=[-1,-3])]
return_steps = 4)]
, W_in1 , W_in1
, n_steps = None , n_steps = None
, truncate_gradient = -1 , truncate_gradient = -1
, go_backwards = False) , go_backwards = False)
outputs = []
outputs += [rval[0][-3:]]
outputs += [rval[1][-2:]]
outputs += [rval[2][-4:]]
f4 = theano.function([u1,u2,x0,y0,W_in1], outputs f4 = theano.function([u1,u2,x0,y0,W_in1], outputs
, updates = updates , updates = updates
, allow_input_downcast = True , allow_input_downcast = True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论