提交 5493b1fd authored 作者: notoraptor's avatar notoraptor

Import Wrapper and Wrap at theano.gof initialization (fix Jenkins issue).

上级 e4e69c4a
...@@ -80,6 +80,8 @@ from theano.gof.type import \ ...@@ -80,6 +80,8 @@ from theano.gof.type import \
from theano.gof.utils import \ from theano.gof.utils import \
hashtype, object2, MethodNotDefined hashtype, object2, MethodNotDefined
from theano.gof.wrapper import Wrapper, Wrap
import theano import theano
if theano.config.cmodule.preload_cache: if theano.config.cmodule.preload_cache:
......
...@@ -799,7 +799,7 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -799,7 +799,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
# We add a default get_params() implementation which will try to detect params from the op # We add a default get_params() implementation which will try to detect params from the op
# if params_type is set to a Wrapper. If not, we raise a MethodNotDefined exception. # if params_type is set to a Wrapper. If not, we raise a MethodNotDefined exception.
def get_params(self, node): def get_params(self, node):
if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.wrapper.Wrapper): if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.Wrapper):
wrapper = self.params_type wrapper = self.params_type
if all(hasattr(self, field) for field in wrapper.fields): if all(hasattr(self, field) for field in wrapper.fields):
wrap_dict = dict() wrap_dict = dict()
...@@ -807,7 +807,7 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -807,7 +807,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
field = wrapper.fields[i] field = wrapper.fields[i]
_type = wrapper.types[i] _type = wrapper.types[i]
wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True) wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True)
return theano.gof.wrapper.Wrap(wrapper, **wrap_dict) return theano.gof.Wrap(wrapper, **wrap_dict)
raise theano.gof.utils.MethodNotDefined('get_params') raise theano.gof.utils.MethodNotDefined('get_params')
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
...@@ -1392,7 +1392,7 @@ class COp(Op): ...@@ -1392,7 +1392,7 @@ class COp(Op):
The names must be strings that are not a C keyword and the The names must be strings that are not a C keyword and the
values must be strings of literal C representations. values must be strings of literal C representations.
If op uses a :class:`theano.gof.wrapper.Wrapper` as ``params_type``, If op uses a :class:`theano.gof.Wrapper` as ``params_type``,
it returns: it returns:
- a default macro ``APPLY_SPECIFIC_WRAPPER`` which defines the class name of the - a default macro ``APPLY_SPECIFIC_WRAPPER`` which defines the class name of the
corresponding C struct. corresponding C struct.
...@@ -1402,7 +1402,7 @@ class COp(Op): ...@@ -1402,7 +1402,7 @@ class COp(Op):
associated to ``key``. associated to ``key``.
""" """
if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.wrapper.Wrapper): if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.Wrapper):
wrapper = self.params_type wrapper = self.params_type
params = [('APPLY_SPECIFIC_WRAPPER', wrapper.name)] params = [('APPLY_SPECIFIC_WRAPPER', wrapper.name)]
for i in range(wrapper.length): for i in range(wrapper.length):
......
...@@ -6,7 +6,7 @@ from theano.gof import Op, COp, Apply ...@@ -6,7 +6,7 @@ from theano.gof import Op, COp, Apply
from theano import Generic from theano import Generic
from theano.scalar import Scalar from theano.scalar import Scalar
from theano.tensor import TensorType from theano.tensor import TensorType
from theano.gof.wrapper import Wrapper, Wrap from theano.gof import Wrapper, Wrap
from theano import tensor from theano import tensor
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
......
...@@ -13,7 +13,7 @@ Importation: ...@@ -13,7 +13,7 @@ Importation:
.. code-block:: python .. code-block:: python
from theano.gof.wrapper import Wrapper from theano.gof import Wrapper
In an op you create: In an op you create:
...@@ -71,7 +71,7 @@ class Wrap(dict): ...@@ -71,7 +71,7 @@ class Wrap(dict):
.. code-block:: python .. code-block:: python
from theano.gof.wrapper import * from theano.gof import Wrapper, Wrap
from theano.scalar import Scalar from theano.scalar import Scalar
# You must create a Wrapper first: # You must create a Wrapper first:
wp = Wrapper(attr1=Scalar('int32'), key2=Scalar('float32'), field3=Scalar('int64')) wp = Wrapper(attr1=Scalar('int32'), key2=Scalar('float32'), field3=Scalar('int64'))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论