提交 2692a34a authored 作者: Matthew Rocklin's avatar Matthew Rocklin

add function to give variables unique names

上级 00183e72
import theano
from theano.gof.utils import give_variables_names, unique
def test_variables_with_names():
x = theano.tensor.matrix('x')
y = x + 1
z = theano.tensor.dot(x, y)
variables = {x,y,z}
give_variables_names(variables)
assert all(var.name for var in variables)
assert unique([var.name for var in variables])
......@@ -7,7 +7,7 @@ import re, os, traceback
def add_tag_trace(thing):
"""Add tag.trace to an node or variable.
The argument is returned after being affected (inplace).
"""
limit = config.traceback.limit
......@@ -20,7 +20,7 @@ def hashgen():
return hashgen.next
hashgen.next = 0
class MethodNotDefined(Exception):
class MethodNotDefined(Exception):
"""
To be raised by functions defined as part of an interface.
......@@ -78,7 +78,7 @@ def memoize(f):
def deprecated(filename, msg=''):
"""Decorator which will print a warning message on the first call.
Use it like this:
@deprecated('myfile', 'do something different...')
......@@ -107,9 +107,9 @@ def uniq(seq):
def difference(seq1, seq2):
"""
Returns all elements in seq1 which are not in seq2: i.e seq1\seq2
Returns all elements in seq1 which are not in seq2: i.e seq1\seq2
"""
try:
try:
# try to use O(const * len(seq1)) algo
if len(seq2) < 4: # I'm guessing this threshold -JB
raise Exception('not worth it')
......@@ -131,7 +131,7 @@ def partition(f, seq):
else:
seqf.append(elem)
return seqt, seqf
def attr_checker(*attrs):
def f(candidate):
for attr in attrs:
......@@ -204,7 +204,7 @@ def toposort(prereqs_d):
# for x, y in prereqs_d.items():
# all2.update(y)
# print all1.difference(all2)
seq = []
done = set()
postreqs_d = {}
......@@ -330,3 +330,27 @@ def flatten(a):
return l
else:
return [a]
def unique(x):
return len(set(x)) == len(x)
def hist(coll):
counts = {}
for elem in coll:
counts[elem] = counts.get(elem, 0) + 1
return counts
def give_variables_names(variables):
""" Gives unique names to an iterable of variables. Modifies input."""
names = map(lambda var: var.name, variables)
h = hist(names)
bad_var = lambda var: h[var.name] > 1
for i, var in enumerate(filter(bad_var, variables)):
var.name = (var.name or "") + "_%d"%i
if not unique(map(str, variables)):
raise ValueError("Not all variables have unique names."
"Maybe you've named some of the variables identically")
return variables
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论