提交 5b50d27d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow inlined Infinity / Nan constants in Composite

上级 6d5cc513
......@@ -432,6 +432,12 @@ class ScalarType(CType, HasDataType, HasShape):
return None
if self.dtype == "bool":
return "1" if data else "0"
if data == np.inf:
return "INFINITY"
if data == -np.inf:
return "-INFINITY"
if np.isnan(data):
return "NAN"
return str(data)
def c_declare(self, name, sub, check_input=True):
......
......@@ -128,21 +128,33 @@ class TestComposite:
# We don't flatten that case.
assert isinstance(CC.outputs[0].owner.op, Composite)
def test_with_constants(self):
@pytest.mark.parametrize("literal_value", (70.0, -np.inf, np.float32("nan")))
def test_with_constants(self, literal_value):
x, y, z = floats("xyz")
e = mul(add(70.0, y), true_div(x, y))
e = mul(add(literal_value, y), true_div(x, y))
comp_op = Composite([x, y], [e])
comp_node = comp_op.make_node(x, y)
c_code = comp_node.op.c_code(comp_node, "dummy", ["x", "y"], ["z"], dict(id=0))
assert "70.0" in c_code
assert constant(literal_value).type.c_literal(literal_value) in c_code
# Make sure caching of the c_code template works
assert hasattr(comp_node.op, "_c_code")
g = FunctionGraph([x, y], [comp_node.out])
fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0) == 36.0
# Default checker does not allow `nan`
def checker(x, y):
np.testing.assert_equal(x, y)
fn = make_function(DualLinker(checker=checker).accept(g))
test_x = 1.0
test_y = 2.0
np.testing.assert_equal(
fn(test_x, test_y),
(literal_value + test_y) * (test_x / test_y),
)
def test_many_outputs(self):
x, y, z = floats("xyz")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论