提交 a50cb8be authored 作者: Olivier Breuleux's avatar Olivier Breuleux

docstring and test for setdefault, fixes #427

上级 f4b23d8c
...@@ -1397,8 +1397,15 @@ class Repeat(gof.Op): ...@@ -1397,8 +1397,15 @@ class Repeat(gof.Op):
repeat = Repeat() repeat = Repeat()
class SetDefault(gof.Op): class SetDefault(gof.Op):
view_map = {0: [1]} """
Takes an input x and a default value. If the input is not None, a
reference to it is returned. If the input is None, a copy of the
default value is returned instead. The input and the default must
have exactly the same type.
"""
view_map = {0: [0]}
def make_node(self, x, default): def make_node(self, x, default):
x, default = as_tensor_variable(x), as_tensor_variable(default)
assert x.type == default.type assert x.type == default.type
return gof.Apply(self, [x, default], [default.type()]) return gof.Apply(self, [x, default], [default.type()])
def perform(self, node, (x, default), (out, )): def perform(self, node, (x, default), (out, )):
......
...@@ -1956,8 +1956,24 @@ def test_convert_to_complex(): ...@@ -1956,8 +1956,24 @@ def test_convert_to_complex():
b = value(numpy.ones(3, dtype='complex64')) b = value(numpy.ones(3, dtype='complex64'))
f = function([a],basic.convert_to_complex64(a)) f = function([a],basic.convert_to_complex64(a))
assert a.type.values_eq_approx(b.data, f(a.data)) assert a.type.values_eq_approx(b.data, f(a.data))
def test_setdefault():
x, y = dscalars('xy')
z = setdefault(x, y)
f = function([x, y], z)
assert f(1, 2) == 1
assert f(None, 2) == 2
assert f(1, None) == 1
def test_setdefault_state():
x, y = dscalars('xy')
z = setdefault(x, 3.8)
new_x = y + z
f = function([y, compile.In(x, update = new_x, value = 12.0)], new_x)
assert f(3) == 15
f['x'] = None
assert f(1) == 4.8
assert f(2.2) == 7
def test_bug_complext_10_august_09(): def test_bug_complext_10_august_09():
v0 = dmatrix() v0 = dmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论