提交 12765623 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a C implementation for TensorFromScalar

上级 9f176da7
...@@ -528,7 +528,7 @@ def get_scalar_constant_value( ...@@ -528,7 +528,7 @@ def get_scalar_constant_value(
raise NotScalarConstantError() raise NotScalarConstantError()
class TensorFromScalar(Op): class TensorFromScalar(COp):
__props__ = () __props__ = ()
...@@ -562,6 +562,25 @@ class TensorFromScalar(Op): ...@@ -562,6 +562,25 @@ class TensorFromScalar(Op):
raise NotImplementedError("grad not implemented for complex dtypes") raise NotImplementedError("grad not implemented for complex dtypes")
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
(z,) = outputs
fail = sub["fail"]
return (
"""
%(z)s = (PyArrayObject*)PyArray_FromScalar(py_%(x)s, NULL);
if(py_%(z)s == NULL){
%(fail)s;
}
Py_XINCREF(%(z)s);
"""
% locals()
)
def c_code_cache_version(self):
return (1,)
tensor_from_scalar = TensorFromScalar() tensor_from_scalar = TensorFromScalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论