提交 e2dd22c4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Remove deprecated and unused functions in theano.gof.test_utils

上级 8e24e782
import theano
from theano.gof.utils import give_variables_names, remove, unique
def test_give_variables_names():
x = theano.tensor.matrix("x")
y = x + 1
z = theano.tensor.dot(x, y)
variables = (x, y, z)
give_variables_names(variables)
assert all(var.name for var in variables)
assert unique([var.name for var in variables])
def test_give_variables_names_idempotence():
x = theano.tensor.matrix("x")
y = x + 1
z = theano.tensor.dot(x, y)
variables = (x, y, z)
give_variables_names(variables)
names = [var.name for var in variables]
give_variables_names(variables)
names2 = [var.name for var in variables]
assert names == names2
def test_give_variables_names_small():
x = theano.tensor.matrix("x")
y = theano.tensor.dot(x, x)
fgraph = theano.FunctionGraph((x,), (y,))
give_variables_names(fgraph.variables)
assert all(var.name for var in fgraph.variables)
assert unique([var.name for var in fgraph.variables])
from theano.gof.utils import remove
def test_remove():
......
......@@ -3,7 +3,6 @@ import linecache
import sys
import traceback
from io import StringIO
from warnings import warn
from theano import config
......@@ -377,39 +376,6 @@ def memoize(f):
return rval
def deprecated(filename, msg=""):
"""
Decorator which will print a warning message on the first call.
Use it like this::
@deprecated('myfile', 'do something different...')
def fn_name(...)
...
And it will print::
WARNING myfile.fn_name deprecated. do something different...
"""
def _deprecated(f):
printme = [True]
def g(*args, **kwargs):
if printme[0]:
warn(
f"{filename}.{f.__name__} deprecated. {msg}",
category=DeprecationWarning,
)
printme[0] = False
return f(*args, **kwargs)
return g
return _deprecated
def uniq(seq):
"""
Do not use set, this must always return the same value at the same index.
......@@ -621,10 +587,6 @@ def flatten(a):
return [a]
def unique(x):
return len(set(x)) == len(x)
def hist(coll):
counts = {}
for elem in coll:
......@@ -632,31 +594,6 @@ def hist(coll):
return counts
@deprecated("theano.gof.utils", msg="Use a_theano_variable.auto_name instead")
def give_variables_names(variables):
"""
Gives unique names to an iterable of variables. Modifies input.
This function is idempotent.
"""
names = [var.name for var in variables]
h = hist(names)
def bad_var(var):
return not var.name or h[var.name] > 1
for i, var in enumerate(filter(bad_var, variables)):
var.name = (var.name or "") + f"_{int(i)}"
if not unique([str(v) for v in variables]):
raise ValueError(
"Not all variables have unique names. Maybe you've "
"named some of the variables identically"
)
return variables
def remove(predicate, coll):
"""
Return those items of collection for which predicate(item) is true.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论