提交 d330a6fc authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move clone_replace tests to tests.graph.test_basic

上级 48ea2d4b
...@@ -4,7 +4,7 @@ from itertools import count ...@@ -4,7 +4,7 @@ from itertools import count
import numpy as np import numpy as np
import pytest import pytest
from aesara import shared from aesara import config, function, shared
from aesara import tensor as at from aesara import tensor as at
from aesara.graph.basic import ( from aesara.graph.basic import (
Apply, Apply,
...@@ -13,6 +13,7 @@ from aesara.graph.basic import ( ...@@ -13,6 +13,7 @@ from aesara.graph.basic import (
applys_between, applys_between,
as_string, as_string,
clone, clone,
clone_replace,
equal_computations, equal_computations,
general_toposort, general_toposort,
get_var_by_name, get_var_by_name,
...@@ -27,9 +28,18 @@ from aesara.graph.basic import ( ...@@ -27,9 +28,18 @@ from aesara.graph.basic import (
from aesara.graph.op import HasInnerGraph, Op from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.tensor.math import max_and_argmax from aesara.tensor.math import max_and_argmax
from aesara.tensor.type import TensorType, iscalars, matrix, scalars from aesara.tensor.type import (
TensorType,
dvector,
fvector,
iscalars,
matrix,
scalars,
vector,
)
from aesara.tensor.type_other import NoneConst from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorVariable from aesara.tensor.var import TensorVariable
from tests import unittest_tools as utt
class MyType(Type): class MyType(Type):
...@@ -537,3 +547,118 @@ def test_get_var_by_name(): ...@@ -537,3 +547,118 @@ def test_get_var_by_name():
(res,) = get_var_by_name([o1, o2], "igo1") (res,) = get_var_by_name([o1, o2], "igo1")
assert res == igo_out_1 assert res == igo_out_1
class TestCloneReplace:
def test_cloning_no_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y in f2_inp
def test_cloning_no_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=False)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y not in f2_inp
def test_cloning_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace={y: y2}, strict=True, share_inputs=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_not_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace={y: y2}, strict=False, share_inputs=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=True, share_inputs=False)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_cloning_replace_not_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=False, share_inputs=False)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_clone(self):
def test(x, y, mention_y):
if mention_y:
d = 0.1 + 0 * y
else:
d = 0.1
out = clone_replace(y, replace={x: x + d})
return function([], out)()
x = shared(np.asarray(0.0, dtype=config.floatX))
utt.assert_allclose(
test(x, at.sum((x + 1) ** 2), mention_y=False), 1.21000003815
)
utt.assert_allclose(
test(x, at.sum((x + 1) ** 2), mention_y=True), 1.21000003815
)
...@@ -36,7 +36,7 @@ from aesara.gradient import ( ...@@ -36,7 +36,7 @@ from aesara.gradient import (
hessian, hessian,
jacobian, jacobian,
) )
from aesara.graph.basic import Apply, clone_replace, graph_inputs from aesara.graph.basic import Apply, clone_replace
from aesara.graph.fg import MissingInputError from aesara.graph.fg import MissingInputError
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
...@@ -2060,106 +2060,6 @@ class TestScan: ...@@ -2060,106 +2060,6 @@ class TestScan:
utt.assert_allclose(out, vR) utt.assert_allclose(out, vR)
def test_cloning_no_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y in f2_inp
def test_cloning_no_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=False)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y not in f2_inp
def test_cloning_replace_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=OrderedDict([(y, y2)]), strict=True, share_inputs=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_not_strict_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(
f1, replace=OrderedDict([(y, y2)]), strict=False, share_inputs=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
def test_cloning_replace_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = vector("y")
y2 = vector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=True, share_inputs=False)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
def test_cloning_replace_not_strict_not_copy_inputs(self):
# This has nothing to do with scan, but it refers to the clone
# function that scan uses internally and that pfunc uses now and
# that users might want to use
x = vector("x")
y = fvector("y")
y2 = dvector("y2")
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=False, share_inputs=False)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
# TEST RE-ordering of inputs # TEST RE-ordering of inputs
# some rnn with multiple outputs and multiple inputs; other # some rnn with multiple outputs and multiple inputs; other
# dimension instead of scalars/vectors # dimension instead of scalars/vectors
...@@ -4325,23 +4225,6 @@ class TestScan: ...@@ -4325,23 +4225,6 @@ class TestScan:
f = function([v], grad(y.sum(), W)) f = function([v], grad(y.sum(), W))
utt.assert_allclose(f([1, 2]), [[0, 0, 0], [1, 1, 1], [1, 1, 1]]) utt.assert_allclose(f([1, 2]), [[0, 0, 0], [1, 1, 1], [1, 1, 1]])
def test_clone(self):
def test(x, y, mention_y):
if mention_y:
d = 0.1 + 0 * y
else:
d = 0.1
out = clone_replace(y, replace={x: x + d})
return function([], out)()
x = shared(np.asarray(0.0, dtype=config.floatX))
utt.assert_allclose(
test(x, at_sum((x + 1) ** 2), mention_y=False), 1.21000003815
)
utt.assert_allclose(
test(x, at_sum((x + 1) ** 2), mention_y=True), 1.21000003815
)
def test_grad_find_input(self): def test_grad_find_input(self):
w = shared(np.array(0, dtype="float32"), name="w") w = shared(np.array(0, dtype="float32"), name="w")
init = fscalar("init") init = fscalar("init")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论