提交 d330a6fc authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move clone_replace tests to tests.graph.test_basic

上级 48ea2d4b
......@@ -4,7 +4,7 @@ from itertools import count
import numpy as np
import pytest
from aesara import shared
from aesara import config, function, shared
from aesara import tensor as at
from aesara.graph.basic import (
Apply,
......@@ -13,6 +13,7 @@ from aesara.graph.basic import (
applys_between,
as_string,
clone,
clone_replace,
equal_computations,
general_toposort,
get_var_by_name,
......@@ -27,9 +28,18 @@ from aesara.graph.basic import (
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.type import Type
from aesara.tensor.math import max_and_argmax
from aesara.tensor.type import TensorType, iscalars, matrix, scalars
from aesara.tensor.type import (
TensorType,
dvector,
fvector,
iscalars,
matrix,
scalars,
vector,
)
from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorVariable
from tests import unittest_tools as utt
class MyType(Type):
......@@ -537,3 +547,118 @@ def test_get_var_by_name():
(res,) = get_var_by_name([o1, o2], "igo1")
assert res == igo_out_1
class TestCloneReplace:
def test_cloning_no_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y in f2_inp
def test_cloning_no_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=False)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y not in f2_inp
def test_cloning_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace={y: y2}, strict=True, share_inputs=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_not_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace={y: y2}, strict=False, share_inputs=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=True, share_inputs=False)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_cloning_replace_not_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=False, share_inputs=False)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_clone(self):
def test(x, y, mention_y):
if mention_y:
d = 0.1 + 0 * y
else:
d = 0.1
out = clone_replace(y, replace={x: x + d})
return function([], out)()
x = shared(np.asarray(0.0, dtype=config.floatX))
utt.assert_allclose(
test(x, at.sum((x + 1) ** 2), mention_y=False), 1.21000003815
)
utt.assert_allclose(
test(x, at.sum((x + 1) ** 2), mention_y=True), 1.21000003815
)
......@@ -36,7 +36,7 @@ from aesara.gradient import (
hessian,
jacobian,
)
from aesara.graph.basic import Apply, clone_replace, graph_inputs
from aesara.graph.basic import Apply, clone_replace
from aesara.graph.fg import MissingInputError
from aesara.graph.op import Op
from aesara.misc.safe_asarray import _asarray
......@@ -2060,106 +2060,6 @@ class TestScan:
utt.assert_allclose(out, vR)
def test_cloning_no_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y in f2_inp
def test_cloning_no_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=False)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y not in f2_inp
def test_cloning_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=OrderedDict([(y, y2)]), strict=True, share_inputs=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_not_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=OrderedDict([(y, y2)]), strict=False, share_inputs=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=True, share_inputs=False)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_cloning_replace_not_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=False, share_inputs=False)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
# TEST RE-ordering of inputs
# some rnn with multiple outputs and multiple inputs; other
# dimension instead of scalars/vectors
......@@ -4325,23 +4225,6 @@ class TestScan:
f = function([v], grad(y.sum(), W))
utt.assert_allclose(f([1, 2]), [[0, 0, 0], [1, 1, 1], [1, 1, 1]])
def test_clone(self):
def test(x, y, mention_y):
if mention_y:
d = 0.1 + 0 * y
else:
d = 0.1
out = clone_replace(y, replace={x: x + d})
return function([], out)()
x = shared(np.asarray(0.0, dtype=config.floatX))
utt.assert_allclose(
test(x, at_sum((x + 1) ** 2), mention_y=False), 1.21000003815
)
utt.assert_allclose(
test(x, at_sum((x + 1) ** 2), mention_y=True), 1.21000003815
)
def test_grad_find_input(self):
w = shared(np.array(0, dtype="float32"), name="w")
init = fscalar("init")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论