Commit 6cd40e55 authored by Razvan Pascanu

I added reduce / foldl/ foldr version of scan; reason : I can add scan into rbm…

I added reduce/foldl/foldr versions of scan. Reason: I can add scan to the RBM tutorial and bypass writing the optimization about space — though I will have to write that at some point as well. Also, reduce/foldl/foldr are nice shortcuts to scan.
Parent f6518a45
...@@ -60,7 +60,7 @@ FancyModule = Module ...@@ -60,7 +60,7 @@ FancyModule = Module
from printing import \ from printing import \
pprint, pp pprint, pp
from scan import scan,map from scan import scan,map, reduce, foldl, foldr
import tensor import tensor
import scalar import scalar
......
...@@ -428,7 +428,6 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -428,7 +428,6 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
for node_idx,node in enumerate(topo): for node_idx,node in enumerate(topo):
astr=apply_name(node) astr=apply_name(node)
g.add_node(pd.Node(astr,shape='box')) g.add_node(pd.Node(astr,shape='box'))
for id,var in enumerate(node.inputs): for id,var in enumerate(node.inputs):
varstr=var_name(var) varstr=var_name(var)
...@@ -462,3 +461,90 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -462,3 +461,90 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
print 'The output file is available at',outfile print 'The output file is available at',outfile
def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.png'), depth = -1):
''' Identical to pydotprint just that it starts from a variable instead
of a compiled function. Could be useful ? '''
try:
import pydot as pd
except:
print "failed to import pydot. Yous must install pydot for this function to work."
return
g=pd.Dot()
my_list = {}
if type(vars) not in (list,tuple):
vars = [vars]
var_str = {}
def var_name(var):
if var in var_str:
return var_str[var]
if var.name is not None:
varstr = var.name
elif isinstance(var,gof.Constant):
dstr = str(var.data)
if '\n' in dstr:
dstr = dstr[:dstr.index('\n')]
if len(dstr) > 30:
dstr = dstr[:27]+'...'
varstr = '%s [%s]'% (dstr, str(var.type))
else:
#a var id is needed as otherwise var with the same type will be merged in the graph.
varstr = str(var.type)
varstr += ' ' + str(len(var_str))
var_str[var]=varstr
return varstr
def apply_name(node):
return str(node.op).replace(':','_')
def plot_apply(app, d):
if d == 0:
return
if app in my_list:
return
astr = apply_name(app) + '_' + str(len(my_list.keys()))
my_list[app] = astr
g.add_node(pd.Node(astr, shape='box'))
for i,nd in enumerate(app.inputs):
if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys()))
my_list[nd] = varastr
g.add_node(pd.Node(varastr))
else:
varastr = my_list[nd]
label = ''
if len(app.inputs)>1:
label = str(i)
g.add_edge(pd.Edge(varastr, astr, label = label))
for i,nd in enumerate(app.outputs):
if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys()))
my_list[nd] = varastr
g.add_node(pd.Node(varastr))
else:
varastr = my_list[nd]
label = ''
if len(app.outputs) > 1:
label = str(i)
g.add_edge(pd.Edge(astr, varastr,label = label))
for nd in app.inputs:
if nd.owner:
plot_apply(nd.owner, d-1)
for nd in vars:
if nd.owner:
plot_apply(nd.owner, depth)
g.write_png(outfile, prog='dot')
print 'The output file is available at',outfile
...@@ -67,15 +67,115 @@ def hash_listsDictsTuples(x): ...@@ -67,15 +67,115 @@ def hash_listsDictsTuples(x):
################################### ###################################
## Implement specific function calls : map, reduce, generate ## Implement specific function calls : map, reduce, generate
def map(fn, sequences, non_sequences = None,
        truncate_gradient = -1, go_backwards = False,
        mode = 'FAST_RUN'):
    ''' Similar behaviour as python map

    :param fn: the function to be applied over the elements in
        sequences (see scan `fn` for more info)
    :param sequences: list of arrays over which map should
        iterate (see scan for more info)
    :param non_sequences: list of other arguments of `fn` over which
        map shouldn't iterate (see scan for more info)
    :param truncate_gradient: see scan for more info
    :param go_backwards: if map should also inverse the order in the
        arrays (see scan for more info)
    :param mode: see scan
    '''
    # Avoid a shared mutable default argument; None stands in for the
    # empty list of non-sequences.
    if non_sequences is None:
        non_sequences = []
    # map is just scan without any recurrent output state.
    return scan(fn, sequences=sequences, outputs_info=[],
                non_sequences=non_sequences,
                truncate_gradient=truncate_gradient,
                go_backwards=go_backwards, mode=mode)
def reduce(fn, sequences, outputs_info, non_sequences = None, go_backwards = False, mode = 'FAST_RUN'):
    ''' Similar behaviour as python reduce

    :param fn: the function to be applied over the elements in
        sequences (see scan `fn` for more info)
    :param sequences: list of arrays over which reduce should
        iterate (see scan for more info)
    :param outputs_info: information about outputs (mainly the initial
        state of each)
    :param non_sequences: list of other arguments of `fn` over which
        reduce shouldn't iterate (see scan for more info)
    :param go_backwards: if reduce should also inverse the order in the
        arrays (see scan for more info)
    :param mode: see scan
    '''
    # Avoid a shared mutable default argument.
    if non_sequences is None:
        non_sequences = []
    # Normalize outputs_info into a fresh list so the caller's argument
    # is never mutated below.
    if type(outputs_info) not in (list, tuple):
        outs_info = [outputs_info]
    else:
        # BUGFIX: this branch used to be `[outputs_info]` as well, which
        # wrapped an already-list argument in an extra list and broke
        # reduce for list-valued outputs_info.
        outs_info = list(outputs_info)
    # Specify that we only want the last value of each output: reduce is
    # scan keeping just the final state.
    for i, out_info in enumerate(outs_info):
        if out_info:
            if not type(out_info) == dict:
                outs_info[i] = dict(initial=out_info, taps=[-1], store_steps=1)
            else:
                # work on a copy so the caller's dict is not mutated;
                # force the use (and storage) of only the last step
                out_info = dict(out_info)
                out_info['taps'] = [-1]
                out_info['store_steps'] = 1
                outs_info[i] = out_info
    # NOTE : maybe some errors could be detected here, where we can give
    # more meaningful error messages than in scan -- RP
    # NOTE(review): truncate_gradient is hard-coded to 1 here while the
    # other wrappers default to -1 -- confirm this is intentional.
    return scan(fn, sequences=sequences, outputs_info=outs_info,
                non_sequences=non_sequences, go_backwards=go_backwards,
                truncate_gradient=1, mode=mode)
def foldl(fn, sequences, outputs_info, non_sequences = [], mode = 'FAST_RUN'):
    ''' Similar behaviour as haskell foldl: a left-to-right reduce.

    :param fn: the function to be applied over the elements in
        sequences (see scan `fn` for more info)
    :param sequences: list of arrays over which foldl should
        iterate (see scan for more info)
    :param outputs_info: information about outputs (mainly the initial
        state of each)
    :param non_sequences: list of other arguments of `fn` over which
        foldl shouldn't iterate (see scan for more info)
    :param mode: see scan
    '''
    # foldl is just reduce walking the sequences front-to-back.
    return reduce(fn=fn,
                  sequences=sequences,
                  outputs_info=outputs_info,
                  non_sequences=non_sequences,
                  go_backwards=False,
                  mode=mode)
def foldr(fn, sequences, outputs_info, non_sequences = [], mode = 'FAST_RUN'):
    ''' Similar behaviour as haskell foldr: a right-to-left reduce.

    :param fn: the function to be applied over the elements in
        sequences (see scan `fn` for more info)
    :param sequences: list of arrays over which foldr should
        iterate (see scan for more info)
    :param outputs_info: information about outputs (mainly the initial
        state of each)
    :param non_sequences: list of other arguments of `fn` over which
        foldr shouldn't iterate (see scan for more info)
    :param mode: see scan
    '''
    # foldr is just reduce walking the sequences back-to-front.
    return reduce(fn=fn,
                  sequences=sequences,
                  outputs_info=outputs_info,
                  non_sequences=non_sequences,
                  go_backwards=True,
                  mode=mode)
# CONSIDER ALTERNATE CALLING CONVENTIONS: # CONSIDER ALTERNATE CALLING CONVENTIONS:
# simple: # simple:
# scan(fn, [a,b], [c]) # scan(fn, [a,b], [c])
...@@ -423,6 +523,12 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[], ...@@ -423,6 +523,12 @@ def scan(fn, sequences=[], outputs_info=[], non_sequences=[],
fromIdx = dummy_notshared_ins + dummy_notshared_init_outs fromIdx = dummy_notshared_ins + dummy_notshared_init_outs
store_steps = [ 0 for i in xrange(n_outs)] store_steps = [ 0 for i in xrange(n_outs)]
for i in xrange(n_outs):
if outs_info[i].get('store_steps', None):
print 'here'
store_steps[i] = outs_info[i]['store_steps']
# add shared variable that act as outputs # add shared variable that act as outputs
# #
n_extended_outs = n_outs n_extended_outs = n_outs
......
...@@ -559,7 +559,7 @@ class T_Scan(unittest.TestCase): ...@@ -559,7 +559,7 @@ class T_Scan(unittest.TestCase):
def test_map(self): def test_map(self):
v = theano.tensor.vector() v = theano.tensor.vector()
abs_expr,abs_updates = theano.map(lambda x: abs(x), v,[],n_steps =0, abs_expr,abs_updates = theano.map(lambda x: abs(x), v,[],
truncate_gradient = -1, go_backwards = False) truncate_gradient = -1, go_backwards = False)
f = theano.function([v],abs_expr,updates = abs_updates) f = theano.function([v],abs_expr,updates = abs_updates)
...@@ -598,6 +598,17 @@ class T_Scan(unittest.TestCase): ...@@ -598,6 +598,17 @@ class T_Scan(unittest.TestCase):
theano_values = f2(v_u,v_x0, W_in, W) theano_values = f2(v_u,v_x0, W_in, W)
assert numpy.allclose( theano_values , v_out) assert numpy.allclose( theano_values , v_out)
def test_reduce(self):
    # reduce with addition over a vector (initial state 0) should match
    # numpy.sum of the same values.
    v = theano.tensor.vector()
    s = theano.tensor.scalar()
    result, updates = theano.reduce(lambda x, y: x + y, v, s)
    f = theano.function([v, s], result, updates=updates)
    rng = numpy.random.RandomState(utt.fetch_seed())
    v_v = rng.uniform(size=(5,), low=-5., high=5.)
    # BUGFIX: use allclose instead of exact `==` for the float
    # comparison, consistent with the other tests in this class.
    assert numpy.allclose(numpy.sum(v_v), f(v_v, 0.))
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment