提交 b2696360 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Specialize Zero Alloc

上级 528b8d4b
......@@ -1634,6 +1634,14 @@ class Alloc(COp):
if v_static_dim is None and value_dim == 1 and out_dim != 1:
raise ValueError(Alloc._runtime_broadcast_error_msg)
@staticmethod
def value_is_scalar_zero(x: TensorVariable) -> bool:
    """Tell whether ``x`` is a constant whose single value is zero.

    The answer is truthy only when all three checks pass, evaluated in
    this order:

    1. every entry of ``x.type.broadcastable`` is true,
    2. ``x`` is an instance of ``Constant``,
    3. ``x.unique_value`` compares equal to ``0``.

    Parameters
    ----------
    x
        The variable to inspect (the fill value of an ``Alloc`` node).

    Returns
    -------
    bool
        ``False`` as soon as one of the first two checks fails, otherwise
        the result of the ``x.unique_value == 0`` comparison.
    """
    # Cheap structural checks first; ``unique_value`` is only read on Constants.
    if not all(x.type.broadcastable):
        return False
    if not isinstance(x, Constant):
        return False
    return x.unique_value == 0
def perform(self, node, inputs, out_):
(out,) = out_
v = inputs[0]
......@@ -1659,6 +1667,7 @@ class Alloc(COp):
o_static_shape = node.outputs[0].type.shape
v_ndim = len(v_static_shape)
o_ndim = len(o_static_shape)
is_zero = self.value_is_scalar_zero(node.inputs[0])
assert o_ndim == len(inp[1:])
# Declare variables
......@@ -1699,16 +1708,18 @@ class Alloc(COp):
{fail}
}}
}}
if ({int(is_zero)} && (PyArray_IS_C_CONTIGUOUS({zz}) || PyArray_IS_F_CONTIGUOUS({zz}))){{
PyArray_FILLWBYTE({zz}, 0);
}}
// This function takes care of broadcasting
if (PyArray_CopyInto({zz}, {vv}) == -1)
else if (PyArray_CopyInto({zz}, {vv}) == -1)
{fail}
"""
return code
def c_code_cache_version(self):
    """Return the version tuple that keys the compiled C code cache.

    The generated C template now contains a zero-fill branch
    (``PyArray_FILLWBYTE``) ahead of the ``PyArray_CopyInto`` call, so the
    version is ``5``. The earlier ``return (4,)`` was the removed side of
    the diff and made the bumped value unreachable.

    Returns
    -------
    tuple
        ``(5,)``
    """
    return (5,)
def infer_shape(self, fgraph, node, input_shapes):
    """Return the output shape of ``node`` as a one-element list.

    The single entry is ``node.inputs[1:]``, i.e. every input after the
    first one. ``fgraph`` and ``input_shapes`` are not used.
    """
    # Inputs after the first one are returned as the shape of the sole output.
    shape_inputs = node.inputs[1:]
    return [shape_inputs]
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论