提交 50c5e6c1 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

changed the name of setdefault op to default

上级 a50cb8be
...@@ -1396,7 +1396,7 @@ class Repeat(gof.Op): ...@@ -1396,7 +1396,7 @@ class Repeat(gof.Op):
repeat = Repeat() repeat = Repeat()
class SetDefault(gof.Op): class Default(gof.Op):
""" """
Takes an input x and a default value. If the input is not None, a Takes an input x and a default value. If the input is not None, a
reference to it is returned. If the input is None, a copy of the reference to it is returned. If the input is None, a copy of the
...@@ -1411,7 +1411,8 @@ class SetDefault(gof.Op): ...@@ -1411,7 +1411,8 @@ class SetDefault(gof.Op):
def perform(self, node, (x, default), (out, )): def perform(self, node, (x, default), (out, )):
out[0] = default.copy() if x is None else x out[0] = default.copy() if x is None else x
setdefault = SetDefault() default = Default()
setdefault = default # legacy
########################## ##########################
......
...@@ -1957,17 +1957,17 @@ def test_convert_to_complex(): ...@@ -1957,17 +1957,17 @@ def test_convert_to_complex():
f = function([a],basic.convert_to_complex64(a)) f = function([a],basic.convert_to_complex64(a))
assert a.type.values_eq_approx(b.data, f(a.data)) assert a.type.values_eq_approx(b.data, f(a.data))
def test_setdefault(): def test_default():
x, y = dscalars('xy') x, y = dscalars('xy')
z = setdefault(x, y) z = default(x, y)
f = function([x, y], z) f = function([x, y], z)
assert f(1, 2) == 1 assert f(1, 2) == 1
assert f(None, 2) == 2 assert f(None, 2) == 2
assert f(1, None) == 1 assert f(1, None) == 1
def test_setdefault_state(): def test_default_state():
x, y = dscalars('xy') x, y = dscalars('xy')
z = setdefault(x, 3.8) z = default(x, 3.8)
new_x = y + z new_x = y + z
f = function([y, compile.In(x, update = new_x, value = 12.0)], new_x) f = function([y, compile.In(x, update = new_x, value = 12.0)], new_x)
assert f(3) == 15 assert f(3) == 15
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论