提交 a724842f authored 作者: notoraptor's avatar notoraptor

Update default `get_op_params()` implementation.

Now, if `params_type` is a Wrapper, it will generate a default macro `APPLY_SPECIFIC_WRAPPER`, and a macro `DTYPE_PARAM_key` for every key in the Wrapper for which associated type implements method `c_element_type()`.
上级 22d04f87
......@@ -1392,7 +1392,24 @@ class COp(Op):
The names must be strings that are not a C keyword and the
values must be strings of literal C representations.
If op uses a :class:`theano.gof.wrapper.Wrapper` as ``params_type``,
it returns:
- a default macro ``APPLY_SPECIFIC_WRAPPER`` which defines the class name of the
corresponding C struct.
- a macro ``DTYPE_PARAM_key`` for every ``key`` in the Wrapper for which associated
type implements the method :func:`theano.gof.type.CLinkerType.c_element_type`.
``DTYPE_PARAM_key`` defines the primitive C type name of an item in a variable
associated to ``key``.
"""
if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.wrapper.Wrapper):
wrapper = self.params_type
params = [('APPLY_SPECIFIC_WRAPPER', wrapper.name)]
for i in range(wrapper.length):
field_c_element_type = wrapper.types[i].c_element_type()
if field_c_element_type:
params.append(('DTYPE_PARAM_' + wrapper.fields[i], field_c_element_type))
return params
return []
def c_code_cache_version(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论