提交 e378dfe4 authored 作者: notoraptor's avatar notoraptor

Update doc for c_support_code. Simplify default's Op.get_params().

上级 17835f40
......@@ -152,7 +152,7 @@ class CLinkerObject(object):
def c_support_code(self):
"""
Optional: Return utility code for use by a `Variable` or `Op` to be
Optional: Return utility code (a string, or a LIST of strings) for use by a `Variable` or `Op` to be
included at global scope prior to the rest of the code for this class.
QUESTION: How many times will this support code be emitted for a graph
......@@ -802,12 +802,10 @@ class Op(utils.object2, PureOp, CLinkerOp):
if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.ParamsType):
wrapper = self.params_type
if not all(hasattr(self, field) for field in wrapper.fields):
raise AttributeError('%s: missing attributes for ParamsType parameter.' % type(self).__name__)
wrap_dict = dict()
for i in range(wrapper.length):
field = wrapper.fields[i]
_type = wrapper.types[i]
wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True)
# Let's print missing attributes for debugging.
not_found = tuple(field for field in wrapper.fields if not hasattr(self, field))
raise AttributeError('%s: missing attributes %s for ParamsType.' % (type(self).__name__, not_found))
# ParamsType.get_params() will apply filtering to attributes.
return self.params_type.get_params(self)
raise theano.gof.utils.MethodNotDefined('get_params')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论