提交 6c610ea5 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Main external function now has non-void arguments + Fixed extenal C example that was out of date.

上级 3a39e221
...@@ -741,22 +741,19 @@ C file named vectorTimesVector.c : ...@@ -741,22 +741,19 @@ C file named vectorTimesVector.c :
#endif #endif
void vector_elemwise_mult_<<<<NODE_NAME_PLACEHOLDER>>>>( void vector_elemwise_mult_<<<<NODE_NAME_PLACEHOLDER>>>>(
npy_%(dtype_x)s* x_ptr, int x_str, DTYPE_INPUT_0* x_ptr, int x_str,
npy_%(dtype_y)s* y_ptr, int y_str, DTYPE_INPUT_1* y_ptr, int y_str,
npy_%(dtype_z)s* z_ptr, int z_str, int nbElements) DTYPE_OUTPUT_0* z_ptr, int z_str, int nbElements)
{ {
for (int i=0; i < nbElements; i++){ for (int i=0; i < nbElements; i++){
z_ptr[i * z_str] = x_ptr[i * x_str] * y_ptr[i * y_str]; z_ptr[i * z_str] = x_ptr[i * x_str] * y_ptr[i * y_str];
} }
} }
int myFunc_<<<<NODE_NAME_PLACEHOLDER>>>>(void* in0, void* in1, int vector_times_vector_<<<<NODE_NAME_PLACEHOLDER>>>>(PyArrayObject* input0,
void** out0) PyArrayObject* input1,
PyArrayObject** output0)
{ {
PyArrayObject* input0 = (PyArrayObject*)in0;
PyArrayObject* input1 = (PyArrayObject*)in1;
PyArrayObject** output0 = (PyArrayObject**)out0;
// Validate that the inputs have the same shape // Validate that the inputs have the same shape
if ( !vector_same_shape(input0, input1)) if ( !vector_same_shape(input0, input1))
{ {
...@@ -798,7 +795,7 @@ C file named vectorTimesVector.c : ...@@ -798,7 +795,7 @@ C file named vectorTimesVector.c :
PyArray_STRIDES(*output0)[0] / ITEMSIZE_OUTPUT_0, PyArray_STRIDES(*output0)[0] / ITEMSIZE_OUTPUT_0,
PyArray_DIMS(input0)[0]); PyArray_DIMS(input0)[0]);
return 0 return 0;
} }
As you can see from this example, the Python and C implementations are nicely As you can see from this example, the Python and C implementations are nicely
......
...@@ -13,6 +13,8 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>" ...@@ -13,6 +13,8 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
import logging import logging
import numpy
import os
import sys import sys
import warnings import warnings
...@@ -1013,7 +1015,7 @@ class COp(Op): ...@@ -1013,7 +1015,7 @@ class COp(Op):
# function. The argstring will be of format : # function. The argstring will be of format :
# "input0, input1, input2, (void**)&output0, (void**)&output1" # "input0, input1, input2, (void**)&output0, (void**)&output1"
input_arg_str = ", ".join(inp) input_arg_str = ", ".join(inp)
output_arg_str = ", ".join(["(void**)&%s"] * len(out)) % tuple(out) output_arg_str = ", ".join(["&%s"] * len(out)) % tuple(out)
return input_arg_str + ", " + output_arg_str return input_arg_str + ", " + output_arg_str
def get_c_macros(self, node, name): def get_c_macros(self, node, name):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论