提交 2692a34a authored 作者: Matthew Rocklin's avatar Matthew Rocklin

add function to give variables unique names

上级 00183e72
import theano
from theano.gof.utils import give_variables_names, unique
def test_variables_with_names():
x = theano.tensor.matrix('x')
y = x + 1
z = theano.tensor.dot(x, y)
variables = {x,y,z}
give_variables_names(variables)
assert all(var.name for var in variables)
assert unique([var.name for var in variables])
...@@ -7,7 +7,7 @@ import re, os, traceback ...@@ -7,7 +7,7 @@ import re, os, traceback
def add_tag_trace(thing): def add_tag_trace(thing):
"""Add tag.trace to an node or variable. """Add tag.trace to an node or variable.
The argument is returned after being affected (inplace). The argument is returned after being affected (inplace).
""" """
limit = config.traceback.limit limit = config.traceback.limit
...@@ -20,7 +20,7 @@ def hashgen(): ...@@ -20,7 +20,7 @@ def hashgen():
return hashgen.next return hashgen.next
hashgen.next = 0 hashgen.next = 0
class MethodNotDefined(Exception): class MethodNotDefined(Exception):
""" """
To be raised by functions defined as part of an interface. To be raised by functions defined as part of an interface.
...@@ -78,7 +78,7 @@ def memoize(f): ...@@ -78,7 +78,7 @@ def memoize(f):
def deprecated(filename, msg=''): def deprecated(filename, msg=''):
"""Decorator which will print a warning message on the first call. """Decorator which will print a warning message on the first call.
Use it like this: Use it like this:
@deprecated('myfile', 'do something different...') @deprecated('myfile', 'do something different...')
...@@ -107,9 +107,9 @@ def uniq(seq): ...@@ -107,9 +107,9 @@ def uniq(seq):
def difference(seq1, seq2): def difference(seq1, seq2):
""" """
Returns all elements in seq1 which are not in seq2: i.e seq1\seq2 Returns all elements in seq1 which are not in seq2: i.e seq1\seq2
""" """
try: try:
# try to use O(const * len(seq1)) algo # try to use O(const * len(seq1)) algo
if len(seq2) < 4: # I'm guessing this threshold -JB if len(seq2) < 4: # I'm guessing this threshold -JB
raise Exception('not worth it') raise Exception('not worth it')
...@@ -131,7 +131,7 @@ def partition(f, seq): ...@@ -131,7 +131,7 @@ def partition(f, seq):
else: else:
seqf.append(elem) seqf.append(elem)
return seqt, seqf return seqt, seqf
def attr_checker(*attrs): def attr_checker(*attrs):
def f(candidate): def f(candidate):
for attr in attrs: for attr in attrs:
...@@ -204,7 +204,7 @@ def toposort(prereqs_d): ...@@ -204,7 +204,7 @@ def toposort(prereqs_d):
# for x, y in prereqs_d.items(): # for x, y in prereqs_d.items():
# all2.update(y) # all2.update(y)
# print all1.difference(all2) # print all1.difference(all2)
seq = [] seq = []
done = set() done = set()
postreqs_d = {} postreqs_d = {}
...@@ -330,3 +330,27 @@ def flatten(a): ...@@ -330,3 +330,27 @@ def flatten(a):
return l return l
else: else:
return [a] return [a]
def unique(x):
return len(set(x)) == len(x)
def hist(coll):
counts = {}
for elem in coll:
counts[elem] = counts.get(elem, 0) + 1
return counts
def give_variables_names(variables):
""" Gives unique names to an iterable of variables. Modifies input."""
names = map(lambda var: var.name, variables)
h = hist(names)
bad_var = lambda var: h[var.name] > 1
for i, var in enumerate(filter(bad_var, variables)):
var.name = (var.name or "") + "_%d"%i
if not unique(map(str, variables)):
raise ValueError("Not all variables have unique names."
"Maybe you've named some of the variables identically")
return variables
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论