提交 b2696360 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Specialize Zero Alloc

上级 528b8d4b
...@@ -1634,6 +1634,14 @@ class Alloc(COp): ...@@ -1634,6 +1634,14 @@ class Alloc(COp):
if v_static_dim is None and value_dim == 1 and out_dim != 1: if v_static_dim is None and value_dim == 1 and out_dim != 1:
raise ValueError(Alloc._runtime_broadcast_error_msg) raise ValueError(Alloc._runtime_broadcast_error_msg)
@staticmethod
def value_is_scalar_zero(x: TensorVariable) -> bool:
    """Return ``True`` when ``x`` is a constant scalar equal to zero.

    "Scalar" here means every dimension is broadcastable (length 1).
    The check is used to pick a cheaper zero-fill path when generating
    C code for the allocation.
    """
    # A non-broadcastable dim means x is not a scalar-like fill value.
    if not all(x.type.broadcastable):
        return False
    # Only constants expose a known value; ``unique_value`` may be None
    # (no single repeated value), and ``None == 0`` is False as desired.
    return isinstance(x, Constant) and x.unique_value == 0
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
(out,) = out_ (out,) = out_
v = inputs[0] v = inputs[0]
...@@ -1659,6 +1667,7 @@ class Alloc(COp): ...@@ -1659,6 +1667,7 @@ class Alloc(COp):
o_static_shape = node.outputs[0].type.shape o_static_shape = node.outputs[0].type.shape
v_ndim = len(v_static_shape) v_ndim = len(v_static_shape)
o_ndim = len(o_static_shape) o_ndim = len(o_static_shape)
is_zero = self.value_is_scalar_zero(node.inputs[0])
assert o_ndim == len(inp[1:]) assert o_ndim == len(inp[1:])
# Declare variables # Declare variables
...@@ -1699,16 +1708,18 @@ class Alloc(COp): ...@@ -1699,16 +1708,18 @@ class Alloc(COp):
{fail} {fail}
}} }}
}} }}
if ({int(is_zero)} && (PyArray_IS_C_CONTIGUOUS({zz}) || PyArray_IS_F_CONTIGUOUS({zz}))){{
PyArray_FILLWBYTE({zz}, 0);
}}
// This function takes care of broadcasting // This function takes care of broadcasting
if (PyArray_CopyInto({zz}, {vv}) == -1) else if (PyArray_CopyInto({zz}, {vv}) == -1)
{fail} {fail}
""" """
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (4,) return (5,)
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
return [node.inputs[1:]] return [node.inputs[1:]]
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论