提交 9815eca2 authored 作者: nouiz's avatar nouiz

Merge pull request #923 from mrocklin/give_variables_names

add function to give variables unique names
import theano
from theano.gof.utils import give_variables_names, unique
def test_give_variables_names():
x = theano.tensor.matrix('x')
y = x + 1
z = theano.tensor.dot(x, y)
variables = (x, y, z)
give_variables_names(variables)
assert all(var.name for var in variables)
assert unique([var.name for var in variables])
def test_give_variables_names_idempotence():
x = theano.tensor.matrix('x')
y = x + 1
z = theano.tensor.dot(x, y)
variables = (x, y, z)
give_variables_names(variables)
names = [var.name for var in variables]
give_variables_names(variables)
names2 = [var.name for var in variables]
assert names == names2
def test_give_variables_names_small():
x = theano.tensor.matrix('x')
y = theano.tensor.dot(x, x)
fgraph = theano.FunctionGraph((x,), (y,))
give_variables_names(fgraph.variables)
assert all(var.name for var in fgraph.variables)
assert unique([var.name for var in fgraph.variables])
...@@ -330,3 +330,29 @@ def flatten(a): ...@@ -330,3 +330,29 @@ def flatten(a):
return l return l
else: else:
return [a] return [a]
def unique(x):
return len(set(x)) == len(x)
def hist(coll):
counts = {}
for elem in coll:
counts[elem] = counts.get(elem, 0) + 1
return counts
def give_variables_names(variables):
""" Gives unique names to an iterable of variables. Modifies input.
This function is idempotent."""
names = map(lambda var: var.name, variables)
h = hist(names)
bad_var = lambda var: not var.name or h[var.name] > 1
for i, var in enumerate(filter(bad_var, variables)):
var.name = (var.name or "") + "_%d"%i
if not unique(map(str, variables)):
raise ValueError("Not all variables have unique names."
"Maybe you've named some of the variables identically")
return variables
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论