提交 9815eca2 authored 作者: nouiz's avatar nouiz

Merge pull request #923 from mrocklin/give_variables_names

add function to give variables unique names
import theano
from theano.gof.utils import give_variables_names, unique
def test_give_variables_names():
x = theano.tensor.matrix('x')
y = x + 1
z = theano.tensor.dot(x, y)
variables = (x, y, z)
give_variables_names(variables)
assert all(var.name for var in variables)
assert unique([var.name for var in variables])
def test_give_variables_names_idempotence():
x = theano.tensor.matrix('x')
y = x + 1
z = theano.tensor.dot(x, y)
variables = (x, y, z)
give_variables_names(variables)
names = [var.name for var in variables]
give_variables_names(variables)
names2 = [var.name for var in variables]
assert names == names2
def test_give_variables_names_small():
x = theano.tensor.matrix('x')
y = theano.tensor.dot(x, x)
fgraph = theano.FunctionGraph((x,), (y,))
give_variables_names(fgraph.variables)
assert all(var.name for var in fgraph.variables)
assert unique([var.name for var in fgraph.variables])
...@@ -7,7 +7,7 @@ import re, os, traceback ...@@ -7,7 +7,7 @@ import re, os, traceback
def add_tag_trace(thing): def add_tag_trace(thing):
"""Add tag.trace to an node or variable. """Add tag.trace to an node or variable.
The argument is returned after being affected (inplace). The argument is returned after being affected (inplace).
""" """
limit = config.traceback.limit limit = config.traceback.limit
...@@ -20,7 +20,7 @@ def hashgen(): ...@@ -20,7 +20,7 @@ def hashgen():
return hashgen.next return hashgen.next
hashgen.next = 0 hashgen.next = 0
class MethodNotDefined(Exception): class MethodNotDefined(Exception):
""" """
To be raised by functions defined as part of an interface. To be raised by functions defined as part of an interface.
...@@ -78,7 +78,7 @@ def memoize(f): ...@@ -78,7 +78,7 @@ def memoize(f):
def deprecated(filename, msg=''): def deprecated(filename, msg=''):
"""Decorator which will print a warning message on the first call. """Decorator which will print a warning message on the first call.
Use it like this: Use it like this:
@deprecated('myfile', 'do something different...') @deprecated('myfile', 'do something different...')
...@@ -107,9 +107,9 @@ def uniq(seq): ...@@ -107,9 +107,9 @@ def uniq(seq):
def difference(seq1, seq2): def difference(seq1, seq2):
""" """
Returns all elements in seq1 which are not in seq2: i.e seq1\seq2 Returns all elements in seq1 which are not in seq2: i.e seq1\seq2
""" """
try: try:
# try to use O(const * len(seq1)) algo # try to use O(const * len(seq1)) algo
if len(seq2) < 4: # I'm guessing this threshold -JB if len(seq2) < 4: # I'm guessing this threshold -JB
raise Exception('not worth it') raise Exception('not worth it')
...@@ -131,7 +131,7 @@ def partition(f, seq): ...@@ -131,7 +131,7 @@ def partition(f, seq):
else: else:
seqf.append(elem) seqf.append(elem)
return seqt, seqf return seqt, seqf
def attr_checker(*attrs): def attr_checker(*attrs):
def f(candidate): def f(candidate):
for attr in attrs: for attr in attrs:
...@@ -204,7 +204,7 @@ def toposort(prereqs_d): ...@@ -204,7 +204,7 @@ def toposort(prereqs_d):
# for x, y in prereqs_d.items(): # for x, y in prereqs_d.items():
# all2.update(y) # all2.update(y)
# print all1.difference(all2) # print all1.difference(all2)
seq = [] seq = []
done = set() done = set()
postreqs_d = {} postreqs_d = {}
...@@ -330,3 +330,29 @@ def flatten(a): ...@@ -330,3 +330,29 @@ def flatten(a):
return l return l
else: else:
return [a] return [a]
def unique(x):
return len(set(x)) == len(x)
def hist(coll):
counts = {}
for elem in coll:
counts[elem] = counts.get(elem, 0) + 1
return counts
def give_variables_names(variables):
""" Gives unique names to an iterable of variables. Modifies input.
This function is idempotent."""
names = map(lambda var: var.name, variables)
h = hist(names)
bad_var = lambda var: not var.name or h[var.name] > 1
for i, var in enumerate(filter(bad_var, variables)):
var.name = (var.name or "") + "_%d"%i
if not unique(map(str, variables)):
raise ValueError("Not all variables have unique names."
"Maybe you've named some of the variables identically")
return variables
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论