Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f35aa65b
提交
f35aa65b
authored
1月 03, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
1月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename PureType to Type and Type to CType
上级
6c6d81c6
显示空白字符变更
内嵌
并排
正在显示
32 个修改的文件
包含
298 行增加
和
319 行删除
+298
-319
cop.txt
doc/extending/cop.txt
+5
-5
ctype.txt
doc/extending/ctype.txt
+20
-23
extending_theano.txt
doc/extending/extending_theano.txt
+3
-3
op.txt
doc/extending/op.txt
+2
-2
other_ops.txt
doc/extending/other_ops.txt
+2
-2
type.txt
doc/extending/type.txt
+11
-12
using_params.txt
doc/extending/using_params.txt
+4
-4
test_cc.py
tests/link/test_cc.py
+2
-2
__init__.py
theano/__init__.py
+1
-1
compiledir.py
theano/compile/compiledir.py
+2
-2
ops.py
theano/compile/ops.py
+3
-2
__init__.py
theano/gof/__init__.py
+1
-1
graph.py
theano/gof/graph.py
+34
-35
null_type.py
theano/gof/null_type.py
+2
-2
params_type.py
theano/gof/params_type.py
+12
-10
type.py
theano/gof/type.py
+13
-58
basic_ops.py
theano/gpuarray/basic_ops.py
+2
-2
subtensor.py
theano/gpuarray/subtensor.py
+4
-3
type.py
theano/gpuarray/type.py
+4
-4
gradient.py
theano/gradient.py
+1
-1
basic.py
theano/link/basic.py
+2
-2
basic.py
theano/link/c/basic.py
+40
-48
jax_dispatch.py
theano/link/jax/jax_dispatch.py
+2
-2
basic.py
theano/scalar/basic.py
+3
-3
type.py
theano/sparse/type.py
+1
-1
basic.py
theano/tensor/basic.py
+4
-4
type.py
theano/tensor/random/type.py
+2
-2
subtensor.py
theano/tensor/subtensor.py
+28
-28
type.py
theano/tensor/type.py
+2
-2
type_other.py
theano/tensor/type_other.py
+2
-2
type.py
theano/typed_list/type.py
+2
-2
utils.py
theano/utils.py
+82
-49
没有找到文件。
doc/extending/cop.txt
浏览文件 @
f35aa65b
...
@@ -25,7 +25,7 @@ input variables and place the variables in the output variables.
...
@@ -25,7 +25,7 @@ input variables and place the variables in the output variables.
What needs to be defined
What needs to be defined
========================
========================
There are less methods to define for a
n `COp` than for a Type
:
There are less methods to define for a
`COp` than for a `Type`
:
.. class:: COp
.. class:: COp
...
@@ -213,9 +213,9 @@ There are less methods to define for an `COp` than for a Type:
...
@@ -213,9 +213,9 @@ There are less methods to define for an `COp` than for a Type:
Optional. If present this method will be called before doing
Optional. If present this method will be called before doing
constant folding of a node, with that node as a parameter. If
constant folding of a node, with that node as a parameter. If
it return True, we will not generate
c
code when doing constant
it return True, we will not generate
C
code when doing constant
folding of this node. This is useful when the compilation of
folding of this node. This is useful when the compilation of
the
c
code will be longer then the computation in python
the
C
code will be longer then the computation in python
(e.g. Elemwise of scalars).
(e.g. Elemwise of scalars).
In addition, this allow to lower the number of compiled module
In addition, this allow to lower the number of compiled module
...
@@ -233,7 +233,7 @@ There are less methods to define for an `COp` than for a Type:
...
@@ -233,7 +233,7 @@ There are less methods to define for an `COp` than for a Type:
considered the same as if the method was not defined.
considered the same as if the method was not defined.
If this method is defined and does not return `None`, then the
If this method is defined and does not return `None`, then the
Op *must* have a `params_type` property with the Type
to use
`Op` *must* have a `params_type` property with the `Type`
to use
for the params variable.
for the params variable.
.. attribute:: _f16_ok
.. attribute:: _f16_ok
...
@@ -252,7 +252,7 @@ There are less methods to define for an `COp` than for a Type:
...
@@ -252,7 +252,7 @@ There are less methods to define for an `COp` than for a Type:
developpment if a better solution is found.
developpment if a better solution is found.
The ``name`` argument is currently given an invalid value, so steer
The ``name`` argument is currently given an invalid value, so steer
away from it. As was the case with
Type
, ``sub['fail']`` provides
away from it. As was the case with
`Type`
, ``sub['fail']`` provides
failure code that you *must* use if you want to raise an exception,
failure code that you *must* use if you want to raise an exception,
after setting the exception message.
after setting the exception message.
...
...
doc/extending/ctype.txt
浏览文件 @
f35aa65b
...
@@ -6,7 +6,7 @@ Implementing double in C
...
@@ -6,7 +6,7 @@ Implementing double in C
========================
========================
The previous two sections described how to define a double :ref:`type`
The previous two sections described how to define a double :ref:`type`
and arithmetic operations on that
Type
, but all of them were
and arithmetic operations on that
`Type`
, but all of them were
implemented in pure Python. In this section we will see how to define
implemented in pure Python. In this section we will see how to define
the double type in such a way that it can be used by operations
the double type in such a way that it can be used by operations
implemented in C (which we will define in the section after that).
implemented in C (which we will define in the section after that).
...
@@ -15,15 +15,15 @@ implemented in C (which we will define in the section after that).
...
@@ -15,15 +15,15 @@ implemented in C (which we will define in the section after that).
How does it work?
How does it work?
=================
=================
In order to be C-compatible, a
Type
must provide a C interface to the
In order to be C-compatible, a
`Type`
must provide a C interface to the
Python data that satisfy the constraints it puts forward. In other
Python data that satisfy the constraints it puts forward. In other
words, it must define C code that can convert a Python reference into
words, it must define C code that can convert a Python reference into
some type suitable for manipulation in C and it must define C code
some type suitable for manipulation in C and it must define C code
that can convert some C structure in which the C implementation of an
that can convert some C structure in which the C implementation of an
operation stores its variables into a reference to an object that can be
operation stores its variables into a reference to an object that can be
used from Python and is a valid value for the
Type
.
used from Python and is a valid value for the
`Type`
.
For example, in the current example, we have a
Type
which represents a
For example, in the current example, we have a
`Type`
which represents a
Python float. First, we will choose a corresponding C type. The
Python float. First, we will choose a corresponding C type. The
natural choice would be the primitive ``double`` type. Then, we need
natural choice would be the primitive ``double`` type. Then, we need
to write code that will take a ``PyObject*``, check that it is a
to write code that will take a ``PyObject*``, check that it is a
...
@@ -42,10 +42,10 @@ find here_.
...
@@ -42,10 +42,10 @@ find here_.
What needs to be defined
What needs to be defined
========================
========================
In order to be C-compatible,
a Type must define several additional
In order to be C-compatible,
the `Type` subclass interface `CType` must be used.
methods, which all start with the ``c_`` prefix. The complete list can
It defines several additional methods, which all start with the ``c_``
be found in the documentation for :class:`.gof.type.Type`. Here, we'll focus on
prefix. The complete list can be found in the documentation for
the most important ones:
:class:`.gof.type.CType`. Here, we'll focus on
the most important ones:
.. class:: CLinkerType
.. class:: CLinkerType
...
@@ -144,7 +144,7 @@ the most important ones:
...
@@ -144,7 +144,7 @@ the most important ones:
Each of these functions take two arguments, ``name`` and ``sub`` which
Each of these functions take two arguments, ``name`` and ``sub`` which
must be used to parameterize the C code they return. ``name`` is a
must be used to parameterize the C code they return. ``name`` is a
string which is chosen by the compiler to represent a :ref:`variable` of
string which is chosen by the compiler to represent a :ref:`variable` of
the
Type
in such a way that there are no name conflicts between
the
`CType`
in such a way that there are no name conflicts between
different pieces of data. Therefore, all variables declared in
different pieces of data. Therefore, all variables declared in
``c_declare`` should have a name which includes ``name``. Furthermore,
``c_declare`` should have a name which includes ``name``. Furthermore,
the name of the variable containing a pointer to the Python object
the name of the variable containing a pointer to the Python object
...
@@ -180,20 +180,19 @@ out:
...
@@ -180,20 +180,19 @@ out:
Defining the methods
Defining the methods
====================
====================
.. testsetup::
import theano
double = theano.Type()
**c_declare**
**c_declare**
.. testcode::
.. testcode::
def c_declare(name, sub):
from theano.gof.type import Generic
class double(Generic):
def c_declare(self, name, sub, check_input=True):
return """
return """
double %(name)s;
double %(name)s;
""" % dict(name = name)
""" % dict(name = name)
double.c_declare = c_declare
Very straightforward. All we need to do is write C code to declare a
Very straightforward. All we need to do is write C code to declare a
double. That double will be named whatever is passed to our function
double. That double will be named whatever is passed to our function
...
@@ -211,7 +210,7 @@ here). Also note that you cannot declare a variable called
...
@@ -211,7 +210,7 @@ here). Also note that you cannot declare a variable called
them.
them.
What you declare there is basically the C interface you are giving to
What you declare there is basically the C interface you are giving to
your
Type
. If you wish people to develop operations that make use of
your
`CType`
. If you wish people to develop operations that make use of
it, it's best to publish it somewhere.
it, it's best to publish it somewhere.
...
@@ -219,11 +218,10 @@ it, it's best to publish it somewhere.
...
@@ -219,11 +218,10 @@ it, it's best to publish it somewhere.
.. testcode::
.. testcode::
def c_init(
name, sub):
def c_init(self,
name, sub):
return """
return """
%(name)s = 0.0;
%(name)s = 0.0;
""" % dict(name = name)
""" % dict(name = name)
double.c_init = c_init
This function has to initialize the
This function has to initialize the
double we declared previously to a suitable value. This is useful if
double we declared previously to a suitable value. This is useful if
...
@@ -245,7 +243,7 @@ called, without knowing for sure which of the two.
...
@@ -245,7 +243,7 @@ called, without knowing for sure which of the two.
.. testcode::
.. testcode::
def c_extract(name, sub
):
def c_extract(self, name, sub, check_input=True
):
return """
return """
if (!PyFloat_Check(py_%(name)s)) {
if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float");
PyErr_SetString(PyExc_TypeError, "expected a float");
...
@@ -253,7 +251,6 @@ called, without knowing for sure which of the two.
...
@@ -253,7 +251,6 @@ called, without knowing for sure which of the two.
}
}
%(name)s = PyFloat_AsDouble(py_%(name)s);
%(name)s = PyFloat_AsDouble(py_%(name)s);
""" % dict(name = name, fail = sub['fail'])
""" % dict(name = name, fail = sub['fail'])
double.c_extract = c_extract
This method is slightly more sophisticated. What happens here is that
This method is slightly more sophisticated. What happens here is that
we have a reference to a Python object which Theano has placed in
we have a reference to a Python object which Theano has placed in
...
@@ -469,9 +466,9 @@ Final version
...
@@ -469,9 +466,9 @@ Final version
.. testcode::
.. testcode::
from theano
import gof
from theano
.gof.type import
class Double(
gof.
Type):
class Double(Type):
def filter(self, x, strict=False, allow_downcast=None):
def filter(self, x, strict=False, allow_downcast=None):
if strict and not isinstance(x, float):
if strict and not isinstance(x, float):
...
...
doc/extending/extending_theano.txt
浏览文件 @
f35aa65b
...
@@ -17,7 +17,7 @@ write a new one. Don't worry, Theano was designed to make it easy to add new
...
@@ -17,7 +17,7 @@ write a new one. Don't worry, Theano was designed to make it easy to add new
Ops, Types, and Optimizations.
Ops, Types, and Optimizations.
.. These first few pages will walk you through the definition of a new :ref:`type`,
.. These first few pages will walk you through the definition of a new :ref:`type`,
.. ``double``, and a basic arithmetic :ref:`operations <op>` on that
Type
.
.. ``double``, and a basic arithmetic :ref:`operations <op>` on that
`Type`
.
As an illustration, this tutorial shows how to write a simple Python-based
As an illustration, this tutorial shows how to write a simple Python-based
:ref:`operations <op>` which performs operations on
:ref:`operations <op>` which performs operations on
...
@@ -134,7 +134,7 @@ or :func:`make_thunk`.
...
@@ -134,7 +134,7 @@ or :func:`make_thunk`.
- it operates on the Variables found in
- it operates on the Variables found in
``*inputs`` in Theano's symbolic language to infer the type of
``*inputs`` in Theano's symbolic language to infer the type of
the symbolic output Variables. It creates output Variables of a suitable
the symbolic output Variables. It creates output Variables of a suitable
symbolic
Type
to serve as the outputs of this op's
symbolic
`Type`
to serve as the outputs of this op's
application.
application.
- it creates an Apply instance with the input and output Variable, and
- it creates an Apply instance with the input and output Variable, and
return the Apply instance.
return the Apply instance.
...
@@ -397,7 +397,7 @@ A common and easy way to ensure inputs are variables is to run them through
...
@@ -397,7 +397,7 @@ A common and easy way to ensure inputs are variables is to run them through
``as_tensor_variable``. This function leaves TensorType variables alone, raises
``as_tensor_variable``. This function leaves TensorType variables alone, raises
an error for non-TensorType variables, and copies any ``numpy.ndarray`` into
an error for non-TensorType variables, and copies any ``numpy.ndarray`` into
the storage for a TensorType Constant. The ``make_node`` method dictates the
the storage for a TensorType Constant. The ``make_node`` method dictates the
appropriate
Type
for all output variables.
appropriate
`Type`
for all output variables.
The ``perform`` method implements the Op's mathematical logic in Python.
The ``perform`` method implements the Op's mathematical logic in Python.
The inputs (here ``x``) are passed by value, but a single output is returned
The inputs (here ``x``) are passed by value, but a single output is returned
...
...
doc/extending/op.txt
浏览文件 @
f35aa65b
...
@@ -50,7 +50,7 @@ define the following methods.
...
@@ -50,7 +50,7 @@ define the following methods.
.. function:: make_node(*inputs)
.. function:: make_node(*inputs)
This method is responsible for creating output Variables of a
This method is responsible for creating output Variables of a
suitable symbolic
Type
to serve as the outputs of this Op's
suitable symbolic
`Type`
to serve as the outputs of this Op's
application. The Variables found in ``*inputs`` must be operated on
application. The Variables found in ``*inputs`` must be operated on
using Theano's symbolic language to compute the symbolic output
using Theano's symbolic language to compute the symbolic output
Variables. This method should put these outputs into an Apply
Variables. This method should put these outputs into an Apply
...
@@ -769,7 +769,7 @@ as first argument to Apply. We define ``perform`` using the function
...
@@ -769,7 +769,7 @@ as first argument to Apply. We define ``perform`` using the function
``fn`` passed in the constructor.
``fn`` passed in the constructor.
This design is a flexible way to define basic operations without
This design is a flexible way to define basic operations without
duplicating code. The same way a
Type
subclass represents a set of
duplicating code. The same way a
`Type`
subclass represents a set of
structurally similar types (see previous section), an `Op` subclass
structurally similar types (see previous section), an `Op` subclass
represents a set of structurally similar operations: operations that
represents a set of structurally similar operations: operations that
have the same input/output types, operations that only differ in one
have the same input/output types, operations that only differ in one
...
...
doc/extending/other_ops.txt
浏览文件 @
f35aa65b
...
@@ -266,8 +266,8 @@ along with pointers to the relevant documentation.
...
@@ -266,8 +266,8 @@ along with pointers to the relevant documentation.
primitive type. The C type associated with this Theano type is the
primitive type. The C type associated with this Theano type is the
represented C primitive itself.
represented C primitive itself.
* :ref:`SparseType <sparse_ops>` : Theano
type
used to represent sparse
* :ref:`SparseType <sparse_ops>` : Theano
`Type`
used to represent sparse
tensors. There is no equivalent C type for this Theano
Type
but you
tensors. There is no equivalent C type for this Theano
`Type`
but you
can split a sparse variable into its parts as TensorVariables. Those
can split a sparse variable into its parts as TensorVariables. Those
can then be used as inputs to an op with C code.
can then be used as inputs to an op with C code.
...
...
doc/extending/type.txt
浏览文件 @
f35aa65b
...
@@ -10,7 +10,7 @@ Making the double type
...
@@ -10,7 +10,7 @@ Making the double type
Type's contract
Type's contract
===============
===============
In Theano's framework, a ``Type`` (:class:`
.gof.type.
Type`)
In Theano's framework, a ``Type`` (:class:`Type`)
is any object which defines the following
is any object which defines the following
methods. To obtain the default methods described below, the Type should
methods. To obtain the default methods described below, the Type should
be an instance of ``Type`` or should be an instance of a
be an instance of ``Type`` or should be an instance of a
...
@@ -22,7 +22,7 @@ i.e. the same default argument names and values. If you wish to add
...
@@ -22,7 +22,7 @@ i.e. the same default argument names and values. If you wish to add
extra arguments to any of these methods, these extra arguments must have
extra arguments to any of these methods, these extra arguments must have
default values.
default values.
.. class::
Pure
Type
.. class:: Type
.. method:: filter(value, strict=False, allow_downcast=None)
.. method:: filter(value, strict=False, allow_downcast=None)
...
@@ -265,21 +265,21 @@ the Type is to instantiate a plain Type and set the needed fields:
...
@@ -265,21 +265,21 @@ the Type is to instantiate a plain Type and set the needed fields:
.. testcode::
.. testcode::
from theano
import gof
from theano
.gof.type import Type
double =
gof.
Type()
double = Type()
double.filter = filter
double.filter = filter
double.values_eq_approx = values_eq_approx
double.values_eq_approx = values_eq_approx
Another way to make this Type is to make a subclass of ``
gof.
Type``
Another way to make this Type is to make a subclass of ``Type``
and define ``filter`` and ``values_eq_approx`` in the subclass:
and define ``filter`` and ``values_eq_approx`` in the subclass:
.. code-block:: python
.. code-block:: python
from theano
import gof
from theano
.gof.type import Type
class Double(
gof.
Type):
class Double(Type):
def filter(self, x, strict=False, allow_downcast=None):
def filter(self, x, strict=False, allow_downcast=None):
# See code above.
# See code above.
...
@@ -300,9 +300,9 @@ instances of ``Double`` are technically the same Type. However, different
...
@@ -300,9 +300,9 @@ instances of ``Double`` are technically the same Type. However, different
.. testsetup::
.. testsetup::
from theano
import gof
from theano
.gof.type import Type
class Double(
gof.
Type):
class Double(Type):
def filter(self, x, strict=False, allow_downcast=None):
def filter(self, x, strict=False, allow_downcast=None):
if strict:
if strict:
...
@@ -399,9 +399,9 @@ Final version
...
@@ -399,9 +399,9 @@ Final version
.. testcode::
.. testcode::
from theano
import gof
from theano
.gof.type import Type
class Double(
gof.
Type):
class Double(Type):
def filter(self, x, strict=False, allow_downcast=None):
def filter(self, x, strict=False, allow_downcast=None):
if strict:
if strict:
...
@@ -432,4 +432,3 @@ Final version
...
@@ -432,4 +432,3 @@ Final version
We add one utility function, ``__str__``. That way, when we print
We add one utility function, ``__str__``. That way, when we print
``double``, it will print out something intelligible.
``double``, it will print out something intelligible.
doc/extending/using_params.txt
浏览文件 @
f35aa65b
...
@@ -48,10 +48,10 @@ The first thing you need to do is to define a Theano Type for your
...
@@ -48,10 +48,10 @@ The first thing you need to do is to define a Theano Type for your
params object. It doesn't have to be complete type because only the
params object. It doesn't have to be complete type because only the
following methods will be used for the type:
following methods will be used for the type:
- :meth:`filter <
Pure
Type.filter>`
- :meth:`filter <Type.filter>`
- :meth:`__eq__ <
Pure
Type.__eq__>`
- :meth:`__eq__ <Type.__eq__>`
- :meth:`__hash__ <
Pure
Type.__hash__>`
- :meth:`__hash__ <Type.__hash__>`
- :meth:`values_eq <
Pure
Type.values_eq>`
- :meth:`values_eq <Type.values_eq>`
Additionaly if you want to use your params with C code, you need to extend `COp`
Additionaly if you want to use your params with C code, you need to extend `COp`
and implement the following methods:
and implement the following methods:
...
...
tests/link/test_cc.py
浏览文件 @
f35aa65b
...
@@ -5,7 +5,7 @@ import theano
...
@@ -5,7 +5,7 @@ import theano
from
theano.gof
import
fg
from
theano.gof
import
fg
from
theano.gof.graph
import
Apply
,
Constant
,
Variable
from
theano.gof.graph
import
Apply
,
Constant
,
Variable
from
theano.gof.op
import
COp
from
theano.gof.op
import
COp
from
theano.gof.type
import
Type
from
theano.gof.type
import
C
Type
from
theano.link.basic
import
PerformLinker
from
theano.link.basic
import
PerformLinker
from
theano.link.c.basic
import
CLinker
,
DualLinker
,
OpWiseCLinker
from
theano.link.c.basic
import
CLinker
,
DualLinker
,
OpWiseCLinker
...
@@ -15,7 +15,7 @@ def as_variable(x):
...
@@ -15,7 +15,7 @@ def as_variable(x):
return
x
return
x
class
TDouble
(
Type
):
class
TDouble
(
C
Type
):
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
False
):
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
False
):
return
float
(
data
)
return
float
(
data
)
...
...
theano/__init__.py
浏览文件 @
f35aa65b
...
@@ -97,12 +97,12 @@ from theano.compile.function.types import FunctionMaker
...
@@ -97,12 +97,12 @@ from theano.compile.function.types import FunctionMaker
from
theano.gof
import
(
from
theano.gof
import
(
Apply
,
Apply
,
Constant
,
Constant
,
CType
,
FunctionGraph
,
FunctionGraph
,
Generic
,
Generic
,
InconsistencyError
,
InconsistencyError
,
Op
,
Op
,
OpenMPOp
,
OpenMPOp
,
Type
,
Variable
,
Variable
,
generic
,
generic
,
opt
,
opt
,
...
...
theano/compile/compiledir.py
浏览文件 @
f35aa65b
...
@@ -55,7 +55,7 @@ def cleanup():
...
@@ -55,7 +55,7 @@ def cleanup():
elif
obj
.
startswith
(
"c_compiler_str="
):
elif
obj
.
startswith
(
"c_compiler_str="
):
have_c_compiler
=
True
have_c_compiler
=
True
elif
isinstance
(
elif
isinstance
(
obj
,
(
theano
.
gof
.
Op
,
theano
.
gof
.
Type
)
obj
,
(
theano
.
gof
.
Op
,
theano
.
gof
.
C
Type
)
)
and
hasattr
(
obj
,
"c_code_cache_version"
):
)
and
hasattr
(
obj
,
"c_code_cache_version"
):
v
=
obj
.
c_code_cache_version
()
v
=
obj
.
c_code_cache_version
()
if
v
not
in
[(),
None
]
and
v
not
in
key
[
0
]:
if
v
not
in
[(),
None
]
and
v
not
in
key
[
0
]:
...
@@ -139,7 +139,7 @@ def print_compiledir_content():
...
@@ -139,7 +139,7 @@ def print_compiledir_content():
{
{
x
x
for
x
in
flatten
(
keydata
.
keys
)
for
x
in
flatten
(
keydata
.
keys
)
if
isinstance
(
x
,
theano
.
gof
.
Type
)
if
isinstance
(
x
,
theano
.
gof
.
C
Type
)
}
}
)
)
compile_start
=
compile_end
=
float
(
"nan"
)
compile_start
=
compile_end
=
float
(
"nan"
)
...
...
theano/compile/ops.py
浏览文件 @
f35aa65b
...
@@ -16,6 +16,7 @@ import theano
...
@@ -16,6 +16,7 @@ import theano
from
theano.gof
import
ParamsType
from
theano.gof
import
ParamsType
from
theano.gof.graph
import
Apply
,
Variable
from
theano.gof.graph
import
Apply
,
Variable
from
theano.gof.op
import
COp
,
Op
from
theano.gof.op
import
COp
,
Op
from
theano.gof.type
import
CType
from
theano.misc.safe_asarray
import
_asarray
from
theano.misc.safe_asarray
import
_asarray
...
@@ -619,11 +620,11 @@ def as_op(itypes, otypes, infer_shape=None):
...
@@ -619,11 +620,11 @@ def as_op(itypes, otypes, infer_shape=None):
"""
"""
if
not
isinstance
(
itypes
,
(
list
,
tuple
)):
if
not
isinstance
(
itypes
,
(
list
,
tuple
)):
itypes
=
[
itypes
]
itypes
=
[
itypes
]
if
any
(
not
isinstance
(
t
,
theano
.
Type
)
for
t
in
itypes
):
if
any
(
not
isinstance
(
t
,
C
Type
)
for
t
in
itypes
):
raise
TypeError
(
"itypes has to be a list of Theano types"
)
raise
TypeError
(
"itypes has to be a list of Theano types"
)
if
not
isinstance
(
otypes
,
(
list
,
tuple
)):
if
not
isinstance
(
otypes
,
(
list
,
tuple
)):
otypes
=
[
otypes
]
otypes
=
[
otypes
]
if
any
(
not
isinstance
(
t
,
theano
.
Type
)
for
t
in
otypes
):
if
any
(
not
isinstance
(
t
,
C
Type
)
for
t
in
otypes
):
raise
TypeError
(
"otypes has to be a list of Theano types"
)
raise
TypeError
(
"otypes has to be a list of Theano types"
)
# make sure they are lists and not tuples
# make sure they are lists and not tuples
...
...
theano/gof/__init__.py
浏览文件 @
f35aa65b
...
@@ -42,5 +42,5 @@ from theano.gof.toolbox import (
...
@@ -42,5 +42,5 @@ from theano.gof.toolbox import (
ReplaceValidate
,
ReplaceValidate
,
Validator
,
Validator
,
)
)
from
theano.gof.type
import
CEnumType
,
EnumList
,
EnumType
,
Generic
,
Type
,
generic
from
theano.gof.type
import
CEnumType
,
CType
,
EnumList
,
EnumType
,
Generic
,
generic
from
theano.gof.utils
import
MetaObject
,
MethodNotDefined
from
theano.gof.utils
import
MetaObject
,
MethodNotDefined
theano/gof/graph.py
浏览文件 @
f35aa65b
...
@@ -223,27 +223,26 @@ class Apply(Node):
...
@@ -223,27 +223,26 @@ class Apply(Node):
return
cp
return
cp
def
clone_with_new_inputs
(
self
,
inputs
,
strict
=
True
):
def
clone_with_new_inputs
(
self
,
inputs
,
strict
=
True
):
"""
"""Duplicate this `Apply` instance in a new graph.
Duplicate this Apply instance in a new graph.
Parameters
Parameters
----------
----------
inputs
inputs
: list of Variables
List of
Variable
instances to use as inputs.
List of
`Variable`
instances to use as inputs.
strict : bool
strict : bool
If
True
, the type fields of all the inputs must be equal
If
``True``
, the type fields of all the inputs must be equal
to the current ones (or compatible, for instance
Tensor
/
to the current ones (or compatible, for instance
`Tensor`
/
GpuArray
of the same dtype and broadcastable patterns,
`GpuArray`
of the same dtype and broadcastable patterns,
in which case they will be converted into current
Type
), and
in which case they will be converted into current
`Type`
), and
returned outputs are guaranteed to have the same types as
returned outputs are guaranteed to have the same types as
self.outputs. If False
, then there's no guarantee that the
``self.outputs``. If ``False``
, then there's no guarantee that the
clone's outputs will have the same types as
self.outputs
,
clone's outputs will have the same types as
``self.outputs``
,
and cloning may not even be possible (it depends on the
Op
).
and cloning may not even be possible (it depends on the
`Op`
).
Returns
Returns
-------
-------
object
object
An
Apply instance with the same op
but different outputs.
An
`Apply` instance with the same `Op`
but different outputs.
"""
"""
assert
isinstance
(
inputs
,
(
list
,
tuple
))
assert
isinstance
(
inputs
,
(
list
,
tuple
))
...
@@ -672,18 +671,18 @@ def walk(
...
@@ -672,18 +671,18 @@ def walk(
Parameters
Parameters
----------
----------
nodes: deque
nodes
: deque
The nodes from which to start walking.
The nodes from which to start walking.
expand: callable
expand
: callable
A callable that is applied to each node in `nodes`, the results of
A callable that is applied to each node in `nodes`, the results of
which are either new nodes to visit or ``None``.
which are either new nodes to visit or ``None``.
bfs: bool
bfs
: bool
If ``True``, breath first search is used; otherwise, depth first
If ``True``, breath first search is used; otherwise, depth first
search.
search.
return_children: bool
return_children
: bool
If ``True``, each output node will be accompanied by the output of
If ``True``, each output node will be accompanied by the output of
`expand` (i.e. the corresponding child nodes).
`expand` (i.e. the corresponding child nodes).
hash_fn: callable
hash_fn
: callable
The function used to produce hashes of the elements in `nodes`.
The function used to produce hashes of the elements in `nodes`.
The default is ``id``.
The default is ``id``.
...
@@ -735,10 +734,10 @@ def ancestors(
...
@@ -735,10 +734,10 @@ def ancestors(
Parameters
Parameters
----------
----------
graphs: list of `Variable` instances
graphs
: list of `Variable` instances
Output `Variable` instances from which to search backward through
Output `Variable` instances from which to search backward through
owners.
owners.
blockers: list of `Variable` instances
blockers
: list of `Variable` instances
A collection of `Variable`s that, when found, prevent the graph search
A collection of `Variable`s that, when found, prevent the graph search
from preceding from that point.
from preceding from that point.
...
@@ -764,10 +763,10 @@ def graph_inputs(
...
@@ -764,10 +763,10 @@ def graph_inputs(
Parameters
Parameters
----------
----------
graphs: list of `Variable` instances
graphs
: list of `Variable` instances
Output `Variable` instances from which to search backward through
Output `Variable` instances from which to search backward through
owners.
owners.
blockers: list of `Variable` instances
blockers
: list of `Variable` instances
A collection of `Variable`s that, when found, prevent the graph search
A collection of `Variable`s that, when found, prevent the graph search
from preceding from that point.
from preceding from that point.
...
@@ -788,9 +787,9 @@ def vars_between(
...
@@ -788,9 +787,9 @@ def vars_between(
Parameters
Parameters
----------
----------
ins: list
ins
: list
Input `Variable`s.
Input `Variable`s.
outs: list
outs
: list
Output `Variable`s.
Output `Variable`s.
Yields
Yields
...
@@ -817,9 +816,9 @@ def orphans_between(
...
@@ -817,9 +816,9 @@ def orphans_between(
Parameters
Parameters
----------
----------
ins: list
ins
: list
Input `Variable`s.
Input `Variable`s.
outs: list
outs
: list
Output `Variable`s.
Output `Variable`s.
Yields
Yields
...
@@ -845,9 +844,9 @@ def applys_between(
...
@@ -845,9 +844,9 @@ def applys_between(
Parameters
Parameters
----------
----------
ins: list
ins
: list
Input `Variable`s.
Input `Variable`s.
outs: list
outs
: list
Output `Variable`s.
Output `Variable`s.
Yields
Yields
...
@@ -972,15 +971,15 @@ def general_toposort(
...
@@ -972,15 +971,15 @@ def general_toposort(
Parameters
Parameters
----------
----------
deps: callable
deps
: callable
A python function that takes a node as input and returns its dependence.
A python function that takes a node as input and returns its dependence.
compute_deps_cache: optional
compute_deps_cache
: optional
If provided deps_cache should also be provided. This is a function like
If provided deps_cache should also be provided. This is a function like
deps, but that also cache its results in a dict passed as deps_cache.
deps, but that also cache its results in a dict passed as deps_cache.
deps_cache: dict
deps_cache
: dict
A dict mapping nodes to their children. This is populated by
A dict mapping nodes to their children. This is populated by
`compute_deps_cache`.
`compute_deps_cache`.
clients: dict
clients
: dict
If a dict is passed it will be filled with a mapping of
If a dict is passed it will be filled with a mapping of
nodes-to-clients for each node in the subgraph.
nodes-to-clients for each node in the subgraph.
...
@@ -1357,9 +1356,9 @@ def list_of_nodes(
...
@@ -1357,9 +1356,9 @@ def list_of_nodes(
Parameters
Parameters
----------
----------
inputs: list of Variable
inputs
: list of Variable
Input `Variable`s.
Input `Variable`s.
outputs: list of Variable
outputs
: list of Variable
Output `Variable`s.
Output `Variable`s.
"""
"""
...
@@ -1380,9 +1379,9 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
...
@@ -1380,9 +1379,9 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
Parameters
Parameters
----------
----------
l_apply: Apply
l_apply
: Apply
The node to walk.
The node to walk.
f_apply: Apply
f_apply
: Apply
The node to find in `l_apply`.
The node to find in `l_apply`.
Returns
Returns
...
...
theano/gof/null_type.py
浏览文件 @
f35aa65b
from
theano.gof.type
import
Type
from
theano.gof.type
import
C
Type
class
NullType
(
Type
):
class
NullType
(
C
Type
):
"""
"""
A type that allows no values.
A type that allows no values.
...
...
theano/gof/params_type.py
浏览文件 @
f35aa65b
...
@@ -116,7 +116,7 @@ for more info about enumeration aliases).
...
@@ -116,7 +116,7 @@ for more info about enumeration aliases).
import
hashlib
import
hashlib
import
re
import
re
from
theano.gof.type
import
EnumType
,
Type
from
theano.gof.type
import
CType
,
Enum
Type
from
theano.gof.utils
import
MethodNotDefined
from
theano.gof.utils
import
MethodNotDefined
...
@@ -315,14 +315,15 @@ class Params(dict):
...
@@ -315,14 +315,15 @@ class Params(dict):
return
not
self
.
__eq__
(
other
)
return
not
self
.
__eq__
(
other
)
class
ParamsType
(
Type
):
class
ParamsType
(
C
Type
):
"""
"""
This class can create a struct of Theano types (like TensorType, GpuArrayType, etc.)
This class can create a struct of Theano types (like `TensorType`,
to be used as a convenience op parameter wrapping many data.
`GpuArrayType`, etc.) to be used as a convenience op parameter wrapping
many data.
ParamsType constructor takes key-value args.
`ParamsType` constructor takes key-value args. Key will be the name of the
Key will be the name of the attribute in the struct.
attribute in the struct. Value is the Theano type of this attribute,
Value is the Theano type of this attribute, ie. an instance of (a subclass of) :class:`
Type`
ie. an instance of (a subclass of) :class:`C
Type`
(eg. ``TensorType('int64', (False,))``).
(eg. ``TensorType('int64', (False,))``).
In a Python code any attribute named ``key`` will be available via::
In a Python code any attribute named ``key`` will be available via::
...
@@ -337,7 +338,8 @@ class ParamsType(Type):
...
@@ -337,7 +338,8 @@ class ParamsType(Type):
.. note::
.. note::
This Type is not complete and should never be used for regular graph operations.
This `Type` is not complete and should never be used for regular graph
operations.
"""
"""
...
@@ -358,9 +360,9 @@ class ParamsType(Type):
...
@@ -358,9 +360,9 @@ class ParamsType(Type):
)
)
type_instance
=
kwargs
[
attribute_name
]
type_instance
=
kwargs
[
attribute_name
]
type_name
=
type_instance
.
__class__
.
__name__
type_name
=
type_instance
.
__class__
.
__name__
if
not
isinstance
(
type_instance
,
Type
):
if
not
isinstance
(
type_instance
,
C
Type
):
raise
TypeError
(
raise
TypeError
(
'ParamsType: attribute "
%
s" should inherit from Theano Type, got "
%
s".'
'ParamsType: attribute "
%
s" should inherit from Theano
C
Type, got "
%
s".'
%
(
attribute_name
,
type_name
)
%
(
attribute_name
,
type_name
)
)
)
...
...
theano/gof/type.py
浏览文件 @
f35aa65b
"""
"""The `Type` classes."""
WRITEME
Defines the `Type` class.
"""
import
ctypes
import
ctypes
...
@@ -16,12 +11,13 @@ from theano.gof import graph, utils
...
@@ -16,12 +11,13 @@ from theano.gof import graph, utils
from
theano.gof.op
import
COp
from
theano.gof.op
import
COp
from
theano.gof.utils
import
MetaObject
,
MethodNotDefined
from
theano.gof.utils
import
MetaObject
,
MethodNotDefined
from
theano.link.c.interface
import
CLinkerType
from
theano.link.c.interface
import
CLinkerType
from
theano.utils
import
Singleton
__docformat__
=
"restructuredtext en"
__docformat__
=
"restructuredtext en"
class
Pure
Type
:
class
Type
:
"""
"""
Interface specification for variable type instances.
Interface specification for variable type instances.
...
@@ -35,10 +31,10 @@ class PureType:
...
@@ -35,10 +31,10 @@ class PureType:
"""
"""
# the type that will be created by call to make_variable.
# the type that will be created by
a
call to make_variable.
Variable
=
graph
.
Variable
Variable
=
graph
.
Variable
# the type that will be created by call to make_constant
# the type that will be created by
a
call to make_constant
Constant
=
graph
.
Constant
Constant
=
graph
.
Constant
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
...
@@ -217,13 +213,9 @@ class PureType:
...
@@ -217,13 +213,9 @@ class PureType:
"""
"""
_nothing
=
"""
class
CType
(
MetaObject
,
Type
,
CLinkerType
):
"""
class
Type
(
MetaObject
,
PureType
,
CLinkerType
):
"""
"""
Convenience wrapper combining `
Pure
Type` and `CLinkerType`.
Convenience wrapper combining `Type` and `CLinkerType`.
Theano comes with several subclasses of such as:
Theano comes with several subclasses of such as:
...
@@ -265,48 +257,11 @@ class Type(MetaObject, PureType, CLinkerType):
...
@@ -265,48 +257,11 @@ class Type(MetaObject, PureType, CLinkerType):
"""
"""
class
SingletonType
(
Type
):
class
Generic
(
CType
,
Singleton
):
"""
Convenient Base class for a Type subclass with no attributes.
It saves having to implement __eq__ and __hash__.
"""
__instance
=
None
def
__new__
(
cls
):
# If sub-subclass of SingletonType don't redeclare __instance
# when we look for it, we will find it in the subclass. We
# don't want that, so we check the class. When we add one, we
# add one only to the current class, so all is working
# correctly.
if
cls
.
__instance
is
None
or
not
isinstance
(
cls
.
__instance
,
cls
):
cls
.
__instance
=
Type
.
__new__
(
cls
)
return
cls
.
__instance
def
__str__
(
self
):
return
self
.
__class__
.
__name__
# even if we try to make a singleton, this do not always work. So
# we compare the type. See test_type_other.test_none_Constant for
# an exmple. So we need to implement __eq__ and __hash__
def
__eq__
(
self
,
other
):
if
self
is
other
:
return
True
if
type
(
self
)
is
type
(
other
):
return
True
return
False
def
__hash__
(
self
):
return
hash
(
type
(
self
))
class
Generic
(
SingletonType
):
"""
"""
Represents a generic Python object.
Represents a generic Python object.
This class implements the `
Pure
Type` and `CLinkerType` interfaces
This class implements the `
C
Type` and `CLinkerType` interfaces
for generic PyObject instances.
for generic PyObject instances.
EXAMPLE of what this means, or when you would use this type.
EXAMPLE of what this means, or when you would use this type.
...
@@ -400,7 +355,7 @@ class _make_cdata(COp):
...
@@ -400,7 +355,7 @@ class _make_cdata(COp):
return
(
0
,
self
.
rtype
.
version
)
return
(
0
,
self
.
rtype
.
version
)
class
CDataType
(
Type
):
class
CDataType
(
C
Type
):
"""
"""
Represents opaque C data to be passed around. The intent is to
Represents opaque C data to be passed around. The intent is to
ease passing arbitrary data between ops C code.
ease passing arbitrary data between ops C code.
...
@@ -613,7 +568,7 @@ class CDataTypeConstant(graph.Constant):
...
@@ -613,7 +568,7 @@ class CDataTypeConstant(graph.Constant):
CDataType
.
Constant
=
CDataTypeConstant
CDataType
.
Constant
=
CDataTypeConstant
class
EnumType
(
Type
,
dict
):
class
EnumType
(
C
Type
,
dict
):
"""
"""
Main subclasses:
Main subclasses:
- :class:`EnumList`
- :class:`EnumList`
...
@@ -804,12 +759,12 @@ class EnumType(Type, dict):
...
@@ -804,12 +759,12 @@ class EnumType(Type, dict):
def
__getattr__
(
self
,
key
):
def
__getattr__
(
self
,
key
):
if
key
in
self
:
if
key
in
self
:
return
self
[
key
]
return
self
[
key
]
return
Type
.
__getattr__
(
self
,
key
)
return
C
Type
.
__getattr__
(
self
,
key
)
def
__setattr__
(
self
,
key
,
value
):
def
__setattr__
(
self
,
key
,
value
):
if
key
in
self
:
if
key
in
self
:
raise
NotImplementedError
(
"constant values are immutable."
)
raise
NotImplementedError
(
"constant values are immutable."
)
Type
.
__setattr__
(
self
,
key
,
value
)
C
Type
.
__setattr__
(
self
,
key
,
value
)
def
__setitem__
(
self
,
key
,
value
):
def
__setitem__
(
self
,
key
,
value
):
raise
NotImplementedError
(
"constant values are immutable."
)
raise
NotImplementedError
(
"constant values are immutable."
)
...
...
theano/gpuarray/basic_ops.py
浏览文件 @
f35aa65b
...
@@ -12,7 +12,7 @@ from theano.gof.graph import Apply, Variable
...
@@ -12,7 +12,7 @@ from theano.gof.graph import Apply, Variable
from
theano.gof.op
import
COp
,
ExternalCOp
,
Op
from
theano.gof.op
import
COp
,
ExternalCOp
,
Op
from
theano.gof.opt
import
copy_stack_trace
from
theano.gof.opt
import
copy_stack_trace
from
theano.gof.params_type
import
ParamsType
from
theano.gof.params_type
import
ParamsType
from
theano.gof.type
import
Type
from
theano.gof.type
import
C
Type
from
theano.gof.utils
import
MethodNotDefined
from
theano.gof.utils
import
MethodNotDefined
from
theano.gradient
import
grad_undefined
from
theano.gradient
import
grad_undefined
from
theano.link.c.interface
import
HideC
from
theano.link.c.interface
import
HideC
...
@@ -220,7 +220,7 @@ class Kernel:
...
@@ -220,7 +220,7 @@ class Kernel:
def
get_dtype
(
t
):
def
get_dtype
(
t
):
if
isinstance
(
t
,
str
):
if
isinstance
(
t
,
str
):
return
np
.
dtype
(
t
)
return
np
.
dtype
(
t
)
elif
isinstance
(
t
,
Type
):
elif
isinstance
(
t
,
C
Type
):
return
t
.
dtype
return
t
.
dtype
elif
isinstance
(
t
,
Variable
):
elif
isinstance
(
t
,
Variable
):
return
t
.
type
.
dtype
return
t
.
type
.
dtype
...
...
theano/gpuarray/subtensor.py
浏览文件 @
f35aa65b
...
@@ -6,6 +6,7 @@ import theano.tensor as tt
...
@@ -6,6 +6,7 @@ import theano.tensor as tt
from
theano
import
gof
from
theano
import
gof
from
theano.gof.op
import
COp
,
Op
from
theano.gof.op
import
COp
,
Op
from
theano.gof.params_type
import
ParamsType
from
theano.gof.params_type
import
ParamsType
from
theano.gof.type
import
CType
from
theano.gradient
import
grad_not_implemented
from
theano.gradient
import
grad_not_implemented
from
theano.link.c.interface
import
HideC
from
theano.link.c.interface
import
HideC
from
theano.scalar
import
bool
as
bool_t
from
theano.scalar
import
bool
as
bool_t
...
@@ -160,7 +161,7 @@ class GpuSubtensor(HideC, Subtensor):
...
@@ -160,7 +161,7 @@ class GpuSubtensor(HideC, Subtensor):
return
"0"
,
1
return
"0"
,
1
elif
isinstance
(
idx
,
(
np
.
integer
,
int
)):
elif
isinstance
(
idx
,
(
np
.
integer
,
int
)):
return
str
(
idx
),
0
return
str
(
idx
),
0
elif
isinstance
(
idx
,
gof
.
Type
):
elif
isinstance
(
idx
,
C
Type
):
return
indices
.
pop
(
0
),
0
return
indices
.
pop
(
0
),
0
else
:
else
:
assert
0
,
idx
assert
0
,
idx
...
@@ -195,7 +196,7 @@ class GpuSubtensor(HideC, Subtensor):
...
@@ -195,7 +196,7 @@ class GpuSubtensor(HideC, Subtensor):
file
=
sio
,
file
=
sio
,
)
)
else
:
else
:
if
isinstance
(
idx
,
gof
.
Type
):
if
isinstance
(
idx
,
C
Type
):
start
=
indices
.
pop
(
0
)
start
=
indices
.
pop
(
0
)
elif
isinstance
(
idx
,
(
np
.
integer
,
int
)):
elif
isinstance
(
idx
,
(
np
.
integer
,
int
)):
start
=
idx
start
=
idx
...
@@ -263,7 +264,7 @@ class GpuIncSubtensor(IncSubtensor):
...
@@ -263,7 +264,7 @@ class GpuIncSubtensor(IncSubtensor):
indices
=
list
(
reversed
(
inputs
[
2
:]))
indices
=
list
(
reversed
(
inputs
[
2
:]))
def
convert
(
entry
):
def
convert
(
entry
):
if
isinstance
(
entry
,
gof
.
Type
):
if
isinstance
(
entry
,
C
Type
):
rval
=
indices
.
pop
()
rval
=
indices
.
pop
()
return
rval
return
rval
elif
isinstance
(
entry
,
slice
):
elif
isinstance
(
entry
,
slice
):
...
...
theano/gpuarray/type.py
浏览文件 @
f35aa65b
...
@@ -6,7 +6,7 @@ import warnings
...
@@ -6,7 +6,7 @@ import warnings
import
numpy
as
np
import
numpy
as
np
import
theano
import
theano
from
theano
import
Constant
,
Type
,
Variable
,
config
,
scalar
,
tensor
from
theano
import
Constant
,
C
Type
,
Variable
,
config
,
scalar
,
tensor
from
theano.compile
import
SharedVariable
from
theano.compile
import
SharedVariable
from
theano.misc.safe_asarray
import
_asarray
from
theano.misc.safe_asarray
import
_asarray
from
theano.tensor.type
import
TensorType
from
theano.tensor.type
import
TensorType
...
@@ -127,7 +127,7 @@ def _unreg_context(name):
...
@@ -127,7 +127,7 @@ def _unreg_context(name):
del
_context_reg
[
name
]
del
_context_reg
[
name
]
class
GpuArrayType
(
Type
):
class
GpuArrayType
(
C
Type
):
"""
"""
The type that represents an array on a gpu.
The type that represents an array on a gpu.
...
@@ -173,7 +173,7 @@ class GpuArrayType(Type):
...
@@ -173,7 +173,7 @@ class GpuArrayType(Type):
See Also
See Also
--------
--------
theano.gof.type.
Pure
Type
theano.gof.type.Type
"""
"""
...
@@ -883,7 +883,7 @@ theano.compile.register_specify_shape_c_code(
...
@@ -883,7 +883,7 @@ theano.compile.register_specify_shape_c_code(
)
)
class
GpuContextType
(
Type
):
class
GpuContextType
(
C
Type
):
"""
"""
Minimal type used for passing contexts to nodes.
Minimal type used for passing contexts to nodes.
...
...
theano/gradient.py
浏览文件 @
f35aa65b
...
@@ -119,7 +119,7 @@ def grad_undefined(op, x_pos, x, comment=""):
...
@@ -119,7 +119,7 @@ def grad_undefined(op, x_pos, x, comment=""):
)()
)()
class
DisconnectedType
(
theano
.
gof
.
type
.
Type
):
class
DisconnectedType
(
theano
.
gof
.
type
.
C
Type
):
"""A type indicating that a variable is a result
"""A type indicating that a variable is a result
of taking the gradient of c with respect to x
of taking the gradient of c with respect to x
...
...
theano/link/basic.py
浏览文件 @
f35aa65b
...
@@ -4,7 +4,7 @@ from copy import copy, deepcopy
...
@@ -4,7 +4,7 @@ from copy import copy, deepcopy
from
theano.configdefaults
import
config
from
theano.configdefaults
import
config
from
theano.gof.fg
import
FunctionGraph
from
theano.gof.fg
import
FunctionGraph
from
theano.gof.graph
import
Apply
from
theano.gof.graph
import
Apply
from
theano.gof.type
import
Type
from
theano.gof.type
import
C
Type
from
theano.link.utils
import
gc_helper
,
map_storage
,
raise_with_op
,
streamline
from
theano.link.utils
import
gc_helper
,
map_storage
,
raise_with_op
,
streamline
from
theano.utils
import
deprecated
,
difference
,
to_return_values
from
theano.utils
import
deprecated
,
difference
,
to_return_values
...
@@ -45,7 +45,7 @@ class Container:
...
@@ -45,7 +45,7 @@ class Container:
):
):
if
not
isinstance
(
storage
,
list
)
or
not
len
(
storage
)
>=
1
:
if
not
isinstance
(
storage
,
list
)
or
not
len
(
storage
)
>=
1
:
raise
TypeError
(
"storage must be a list of length at least one"
)
raise
TypeError
(
"storage must be a list of length at least one"
)
if
isinstance
(
r
,
Type
):
if
isinstance
(
r
,
C
Type
):
self
.
type
=
r
self
.
type
=
r
else
:
else
:
self
.
type
=
r
.
type
self
.
type
=
r
.
type
...
...
theano/link/c/basic.py
浏览文件 @
f35aa65b
...
@@ -6,6 +6,7 @@ import logging
...
@@ -6,6 +6,7 @@ import logging
import
os
import
os
import
sys
import
sys
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
suppress
from
copy
import
copy
from
copy
import
copy
from
io
import
StringIO
from
io
import
StringIO
...
@@ -25,6 +26,7 @@ from theano.link.c.cmodule import (
...
@@ -25,6 +26,7 @@ from theano.link.c.cmodule import (
dlimport_workdir
,
dlimport_workdir
,
)
)
from
theano.link.c.cmodule
import
get_module_cache
as
_get_module_cache
from
theano.link.c.cmodule
import
get_module_cache
as
_get_module_cache
from
theano.link.c.interface
import
CLinkerObject
,
CLinkerOp
,
CLinkerType
from
theano.link.utils
import
gc_helper
,
map_storage
,
raise_with_op
,
streamline
from
theano.link.utils
import
gc_helper
,
map_storage
,
raise_with_op
,
streamline
from
theano.utils
import
difference
,
uniq
from
theano.utils
import
difference
,
uniq
...
@@ -667,13 +669,13 @@ class CLinker(Linker):
...
@@ -667,13 +669,13 @@ class CLinker(Linker):
self
.
consts
=
[]
self
.
consts
=
[]
# Move c type from orphans (theano.scalar.Scalar) to self.consts
# Move c type from orphans (theano.scalar.Scalar) to self.consts
for
variable
in
self
.
orphans
:
for
variable
in
self
.
orphans
:
if
isinstance
(
variable
,
Constant
):
if
isinstance
(
variable
,
Constant
)
and
isinstance
(
try
:
variable
.
type
,
CLinkerType
):
with
suppress
(
MethodNotDefined
,
NotImplementedError
):
variable
.
type
.
c_literal
(
variable
.
data
)
variable
.
type
.
c_literal
(
variable
.
data
)
self
.
consts
.
append
(
variable
)
self
.
consts
.
append
(
variable
)
self
.
orphans
.
remove
(
variable
)
self
.
orphans
.
remove
(
variable
)
except
(
MethodNotDefined
,
NotImplementedError
):
pass
self
.
temps
=
list
(
self
.
temps
=
list
(
set
(
self
.
variables
)
set
(
self
.
variables
)
...
@@ -721,6 +723,10 @@ class CLinker(Linker):
...
@@ -721,6 +723,10 @@ class CLinker(Linker):
id
=
1
id
=
1
for
variable
in
self
.
variables
:
for
variable
in
self
.
variables
:
if
not
isinstance
(
variable
.
type
,
CLinkerType
):
raise
NotImplementedError
(
f
"Type of {variable} cannot produce C code"
)
sub
=
dict
(
failure_var
=
failure_var
)
sub
=
dict
(
failure_var
=
failure_var
)
# it might be possible to inline constant variables as C literals
# it might be possible to inline constant variables as C literals
...
@@ -816,6 +822,11 @@ class CLinker(Linker):
...
@@ -816,6 +822,11 @@ class CLinker(Linker):
for
node_num
,
node
in
enumerate
(
self
.
node_order
):
for
node_num
,
node
in
enumerate
(
self
.
node_order
):
op
=
node
.
op
if
not
isinstance
(
op
,
CLinkerOp
):
raise
NotImplementedError
(
f
"{op} cannot produce C code"
)
sub
=
dict
(
failure_var
=
failure_var
)
sub
=
dict
(
failure_var
=
failure_var
)
params
=
node
.
run_params
()
params
=
node
.
run_params
()
...
@@ -849,56 +860,43 @@ class CLinker(Linker):
...
@@ -849,56 +860,43 @@ class CLinker(Linker):
struct_init
=
""
struct_init
=
""
struct_cleanup
=
""
struct_cleanup
=
""
op
=
node
.
op
with
suppress
(
MethodNotDefined
):
# type-specific support code
try
:
c_support_code_apply
.
append
(
op
.
c_support_code_apply
(
node
,
name
))
c_support_code_apply
.
append
(
op
.
c_support_code_apply
(
node
,
name
))
except
MethodNotDefined
:
pass
else
:
# The following will be executed if the "try" block succeeds
assert
isinstance
(
c_support_code_apply
[
-
1
],
str
),
(
assert
isinstance
(
c_support_code_apply
[
-
1
],
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_support_code_apply"
str
(
node
.
op
)
+
" didn't return a string for c_support_code_apply"
)
)
try
:
with
suppress
(
MethodNotDefined
)
:
c_init_code_apply
.
append
(
op
.
c_init_code_apply
(
node
,
name
))
c_init_code_apply
.
append
(
op
.
c_init_code_apply
(
node
,
name
))
except
MethodNotDefined
:
pass
else
:
assert
isinstance
(
c_init_code_apply
[
-
1
],
str
),
(
assert
isinstance
(
c_init_code_apply
[
-
1
],
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_init_code_apply"
str
(
node
.
op
)
+
" didn't return a string for c_init_code_apply"
)
)
try
:
with
suppress
(
MethodNotDefined
)
:
struct_init
=
op
.
c_init_code_struct
(
node
,
name
,
sub_struct
)
struct_init
=
op
.
c_init_code_struct
(
node
,
name
,
sub_struct
)
assert
isinstance
(
struct_init
,
str
),
(
assert
isinstance
(
struct_init
,
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_init_code_struct"
str
(
node
.
op
)
+
" didn't return a string for c_init_code_struct"
)
)
except
MethodNotDefined
:
pass
try
:
with
suppress
(
MethodNotDefined
)
:
struct_support
=
op
.
c_support_code_struct
(
node
,
name
)
struct_support
=
op
.
c_support_code_struct
(
node
,
name
)
assert
isinstance
(
struct_support
,
str
),
(
assert
isinstance
(
struct_support
,
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_support_code_struct"
str
(
node
.
op
)
+
" didn't return a string for c_support_code_struct"
)
)
except
MethodNotDefined
:
pass
try
:
with
suppress
(
MethodNotDefined
)
:
struct_cleanup
=
op
.
c_cleanup_code_struct
(
node
,
name
)
struct_cleanup
=
op
.
c_cleanup_code_struct
(
node
,
name
)
assert
isinstance
(
struct_cleanup
,
str
),
(
assert
isinstance
(
struct_cleanup
,
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_cleanup_code_struct"
str
(
node
.
op
)
+
" didn't return a string for c_cleanup_code_struct"
)
)
except
MethodNotDefined
:
pass
# emit c_code
# emit c_code
try
:
try
:
behavior
=
op
.
c_code
(
node
,
name
,
isyms
,
osyms
,
sub
)
behavior
=
op
.
c_code
(
node
,
name
,
isyms
,
osyms
,
sub
)
except
MethodNotDefined
:
except
MethodNotDefined
:
raise
NotImplementedError
(
f
"{op} cannot produce C code"
)
raise
NotImplementedError
(
f
"{op} cannot produce C code"
)
assert
isinstance
(
assert
isinstance
(
behavior
,
str
behavior
,
str
),
f
"{node.op} didn't return a string for c_code"
),
f
"{node.op} didn't return a string for c_code"
...
@@ -987,14 +985,12 @@ class CLinker(Linker):
...
@@ -987,14 +985,12 @@ class CLinker(Linker):
)
)
# generic support code
# generic support code
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
try
:
with
suppress
(
MethodNotDefined
)
:
support_code
=
x
.
c_support_code
()
support_code
=
x
.
c_support_code
()
if
isinstance
(
support_code
,
list
):
if
isinstance
(
support_code
,
list
):
ret
.
extend
(
support_code
)
ret
.
extend
(
support_code
)
else
:
else
:
ret
.
append
(
support_code
)
ret
.
append
(
support_code
)
except
MethodNotDefined
:
pass
return
ret
return
ret
def
compile_args
(
self
):
def
compile_args
(
self
):
...
@@ -1026,20 +1022,20 @@ class CLinker(Linker):
...
@@ -1026,20 +1022,20 @@ class CLinker(Linker):
c_compiler
=
self
.
c_compiler
()
c_compiler
=
self
.
c_compiler
()
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
try
:
if
isinstance
(
x
,
CLinkerObject
):
with
suppress
(
MethodNotDefined
):
try
:
try
:
ret
+=
x
.
c_compile_args
(
c_compiler
)
ret
+=
x
.
c_compile_args
(
c_compiler
)
except
TypeError
:
except
TypeError
:
ret
+=
x
.
c_compile_args
()
ret
+=
x
.
c_compile_args
()
except
MethodNotDefined
:
pass
ret
=
uniq
(
ret
)
# to remove duplicate
ret
=
uniq
(
ret
)
# to remove duplicate
# The args set by the compiler include the user flags. We do not want
# The args set by the compiler include the user flags. We do not want
# to reorder them
# to reorder them
ret
+=
c_compiler
.
compile_args
()
ret
+=
c_compiler
.
compile_args
()
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
try
:
if
isinstance
(
x
,
CLinkerObject
):
with
suppress
(
MethodNotDefined
):
try
:
try
:
no_comp
=
x
.
c_no_compile_args
(
c_compiler
)
no_comp
=
x
.
c_no_compile_args
(
c_compiler
)
except
TypeError
:
except
TypeError
:
...
@@ -1049,8 +1045,6 @@ class CLinker(Linker):
...
@@ -1049,8 +1045,6 @@ class CLinker(Linker):
ret
.
remove
(
i
)
ret
.
remove
(
i
)
except
ValueError
:
except
ValueError
:
pass
# in case the value is not there
pass
# in case the value is not there
except
MethodNotDefined
:
pass
return
ret
return
ret
def
headers
(
self
):
def
headers
(
self
):
...
@@ -1064,13 +1058,12 @@ class CLinker(Linker):
...
@@ -1064,13 +1058,12 @@ class CLinker(Linker):
ret
=
[]
ret
=
[]
c_compiler
=
self
.
c_compiler
()
c_compiler
=
self
.
c_compiler
()
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
try
:
if
isinstance
(
x
,
CLinkerObject
):
with
suppress
(
MethodNotDefined
):
try
:
try
:
ret
+=
x
.
c_headers
(
c_compiler
)
ret
+=
x
.
c_headers
(
c_compiler
)
except
TypeError
:
except
TypeError
:
ret
+=
x
.
c_headers
()
ret
+=
x
.
c_headers
()
except
MethodNotDefined
:
pass
return
uniq
(
ret
)
return
uniq
(
ret
)
def
init_code
(
self
):
def
init_code
(
self
):
...
@@ -1083,15 +1076,15 @@ class CLinker(Linker):
...
@@ -1083,15 +1076,15 @@ class CLinker(Linker):
"""
"""
ret
=
[]
ret
=
[]
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
try
:
if
isinstance
(
x
,
CLinkerObject
):
with
suppress
(
MethodNotDefined
):
ret
+=
x
.
c_init_code
()
ret
+=
x
.
c_init_code
()
except
MethodNotDefined
:
pass
return
uniq
(
ret
)
return
uniq
(
ret
)
def
c_compiler
(
self
):
def
c_compiler
(
self
):
c_compiler
=
None
c_compiler
=
None
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
# FIXME: Why would a `Type` have a `c_compiler` field?!
if
hasattr
(
x
,
"c_compiler"
):
if
hasattr
(
x
,
"c_compiler"
):
x_compiler
=
x
.
c_compiler
()
x_compiler
=
x
.
c_compiler
()
else
:
else
:
...
@@ -1121,13 +1114,12 @@ class CLinker(Linker):
...
@@ -1121,13 +1114,12 @@ class CLinker(Linker):
ret
=
[]
ret
=
[]
c_compiler
=
self
.
c_compiler
()
c_compiler
=
self
.
c_compiler
()
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
try
:
if
isinstance
(
x
,
CLinkerObject
):
with
suppress
(
MethodNotDefined
):
try
:
try
:
ret
+=
x
.
c_header_dirs
(
c_compiler
)
ret
+=
x
.
c_header_dirs
(
c_compiler
)
except
TypeError
:
except
TypeError
:
ret
+=
x
.
c_header_dirs
()
ret
+=
x
.
c_header_dirs
()
except
MethodNotDefined
:
pass
# filter out empty strings/None
# filter out empty strings/None
return
[
r
for
r
in
uniq
(
ret
)
if
r
]
return
[
r
for
r
in
uniq
(
ret
)
if
r
]
...
@@ -1142,13 +1134,12 @@ class CLinker(Linker):
...
@@ -1142,13 +1134,12 @@ class CLinker(Linker):
ret
=
[]
ret
=
[]
c_compiler
=
self
.
c_compiler
()
c_compiler
=
self
.
c_compiler
()
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
try
:
if
isinstance
(
x
,
CLinkerObject
):
with
suppress
(
MethodNotDefined
):
try
:
try
:
ret
+=
x
.
c_libraries
(
c_compiler
)
ret
+=
x
.
c_libraries
(
c_compiler
)
except
TypeError
:
except
TypeError
:
ret
+=
x
.
c_libraries
()
ret
+=
x
.
c_libraries
()
except
MethodNotDefined
:
pass
return
uniq
(
ret
)
return
uniq
(
ret
)
def
lib_dirs
(
self
):
def
lib_dirs
(
self
):
...
@@ -1162,13 +1153,12 @@ class CLinker(Linker):
...
@@ -1162,13 +1153,12 @@ class CLinker(Linker):
ret
=
[]
ret
=
[]
c_compiler
=
self
.
c_compiler
()
c_compiler
=
self
.
c_compiler
()
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
try
:
if
isinstance
(
x
,
CLinkerObject
):
with
suppress
(
MethodNotDefined
):
try
:
try
:
ret
+=
x
.
c_lib_dirs
(
c_compiler
)
ret
+=
x
.
c_lib_dirs
(
c_compiler
)
except
TypeError
:
except
TypeError
:
ret
+=
x
.
c_lib_dirs
()
ret
+=
x
.
c_lib_dirs
()
except
MethodNotDefined
:
pass
# filter out empty strings/None
# filter out empty strings/None
return
[
r
for
r
in
uniq
(
ret
)
if
r
]
return
[
r
for
r
in
uniq
(
ret
)
if
r
]
...
@@ -1542,8 +1532,10 @@ class CLinker(Linker):
...
@@ -1542,8 +1532,10 @@ class CLinker(Linker):
if
hasattr
(
node
.
op
,
"__props__"
):
if
hasattr
(
node
.
op
,
"__props__"
):
version
.
append
(
node
.
op
.
__props__
)
version
.
append
(
node
.
op
.
__props__
)
for
i
in
node
.
inputs
:
for
i
in
node
.
inputs
:
if
isinstance
(
i
.
type
,
CLinkerObject
):
version
.
append
(
i
.
type
.
c_code_cache_version
())
version
.
append
(
i
.
type
.
c_code_cache_version
())
for
o
in
node
.
outputs
:
for
o
in
node
.
outputs
:
if
isinstance
(
o
.
type
,
CLinkerObject
):
version
.
append
(
o
.
type
.
c_code_cache_version
())
version
.
append
(
o
.
type
.
c_code_cache_version
())
# add the signature for this node
# add the signature for this node
...
...
theano/link/jax/jax_dispatch.py
浏览文件 @
f35aa65b
...
@@ -6,7 +6,6 @@ import jax
...
@@ -6,7 +6,6 @@ import jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
import
jax.scipy
as
jsp
import
jax.scipy
as
jsp
import
theano
from
theano.compile.ops
import
(
from
theano.compile.ops
import
(
DeepCopyOp
,
DeepCopyOp
,
Rebroadcast
,
Rebroadcast
,
...
@@ -17,6 +16,7 @@ from theano.compile.ops import (
...
@@ -17,6 +16,7 @@ from theano.compile.ops import (
)
)
from
theano.configdefaults
import
config
from
theano.configdefaults
import
config
from
theano.gof
import
FunctionGraph
from
theano.gof
import
FunctionGraph
from
theano.gof.type
import
CType
from
theano.ifelse
import
IfElse
from
theano.ifelse
import
IfElse
from
theano.scalar.basic
import
Cast
,
Clip
,
Composite
,
Identity
,
ScalarOp
,
Second
from
theano.scalar.basic
import
Cast
,
Clip
,
Composite
,
Identity
,
ScalarOp
,
Second
from
theano.scan.op
import
Scan
from
theano.scan.op
import
Scan
...
@@ -589,7 +589,7 @@ def jax_funcify_IfElse(op):
...
@@ -589,7 +589,7 @@ def jax_funcify_IfElse(op):
def
convert_indices
(
indices
,
entry
):
def
convert_indices
(
indices
,
entry
):
if
indices
and
isinstance
(
entry
,
theano
.
gof
.
Type
):
if
indices
and
isinstance
(
entry
,
C
Type
):
rval
=
indices
.
pop
(
0
)
rval
=
indices
.
pop
(
0
)
return
rval
return
rval
elif
isinstance
(
entry
,
slice
):
elif
isinstance
(
entry
,
slice
):
...
...
theano/scalar/basic.py
浏览文件 @
f35aa65b
...
@@ -25,7 +25,7 @@ from theano.gof.fg import FunctionGraph
...
@@ -25,7 +25,7 @@ from theano.gof.fg import FunctionGraph
from
theano.gof.graph
import
Apply
,
Constant
,
Variable
,
clone
,
list_of_nodes
from
theano.gof.graph
import
Apply
,
Constant
,
Variable
,
clone
,
list_of_nodes
from
theano.gof.op
import
COp
from
theano.gof.op
import
COp
from
theano.gof.opt
import
MergeOptimizer
from
theano.gof.opt
import
MergeOptimizer
from
theano.gof.type
import
Type
from
theano.gof.type
import
C
Type
from
theano.gof.utils
import
MetaObject
,
MethodNotDefined
from
theano.gof.utils
import
MetaObject
,
MethodNotDefined
from
theano.gradient
import
DisconnectedType
,
grad_undefined
from
theano.gradient
import
DisconnectedType
,
grad_undefined
from
theano.misc.safe_asarray
import
_asarray
from
theano.misc.safe_asarray
import
_asarray
...
@@ -313,7 +313,7 @@ def constant(x, name=None, dtype=None):
...
@@ -313,7 +313,7 @@ def constant(x, name=None, dtype=None):
return
ScalarConstant
(
get_scalar_type
(
str
(
x
.
dtype
)),
x
,
name
=
name
)
return
ScalarConstant
(
get_scalar_type
(
str
(
x
.
dtype
)),
x
,
name
=
name
)
class
Scalar
(
Type
):
class
Scalar
(
C
Type
):
"""
"""
Internal class, should not be used by clients.
Internal class, should not be used by clients.
...
@@ -1096,7 +1096,7 @@ class ScalarOp(COp):
...
@@ -1096,7 +1096,7 @@ class ScalarOp(COp):
if
hasattr
(
self
,
"output_types_preference"
):
if
hasattr
(
self
,
"output_types_preference"
):
variables
=
self
.
output_types_preference
(
*
types
)
variables
=
self
.
output_types_preference
(
*
types
)
if
not
isinstance
(
variables
,
(
list
,
tuple
))
or
any
(
if
not
isinstance
(
variables
,
(
list
,
tuple
))
or
any
(
not
isinstance
(
x
,
Type
)
for
x
in
variables
not
isinstance
(
x
,
C
Type
)
for
x
in
variables
):
):
raise
TypeError
(
raise
TypeError
(
"output_types_preference should return a list or a tuple of types"
,
"output_types_preference should return a list or a tuple of types"
,
...
...
theano/sparse/type.py
浏览文件 @
f35aa65b
...
@@ -32,7 +32,7 @@ def _is_sparse(x):
...
@@ -32,7 +32,7 @@ def _is_sparse(x):
return
isinstance
(
x
,
scipy
.
sparse
.
spmatrix
)
return
isinstance
(
x
,
scipy
.
sparse
.
spmatrix
)
class
SparseType
(
gof
.
Type
):
class
SparseType
(
gof
.
C
Type
):
"""
"""
Fundamental way to create a sparse node.
Fundamental way to create a sparse node.
...
...
theano/tensor/basic.py
浏览文件 @
f35aa65b
...
@@ -565,7 +565,7 @@ def get_scalar_constant_value(
...
@@ -565,7 +565,7 @@ def get_scalar_constant_value(
var
.
ndim
==
0
for
var
in
v
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
:]
var
.
ndim
==
0
for
var
in
v
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
:]
):
):
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
if
isinstance
(
idx
,
gof
.
Type
):
if
isinstance
(
idx
,
gof
.
C
Type
):
idx
=
get_scalar_constant_value
(
idx
=
get_scalar_constant_value
(
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
)
...
@@ -579,7 +579,7 @@ def get_scalar_constant_value(
...
@@ -579,7 +579,7 @@ def get_scalar_constant_value(
var
.
ndim
==
1
for
var
in
v
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
:]
var
.
ndim
==
1
for
var
in
v
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
:]
):
):
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
if
isinstance
(
idx
,
gof
.
Type
):
if
isinstance
(
idx
,
gof
.
C
Type
):
idx
=
get_scalar_constant_value
(
idx
=
get_scalar_constant_value
(
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
)
...
@@ -616,7 +616,7 @@ def get_scalar_constant_value(
...
@@ -616,7 +616,7 @@ def get_scalar_constant_value(
):
):
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
if
isinstance
(
idx
,
gof
.
Type
):
if
isinstance
(
idx
,
gof
.
C
Type
):
idx
=
get_scalar_constant_value
(
idx
=
get_scalar_constant_value
(
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
)
...
@@ -638,7 +638,7 @@ def get_scalar_constant_value(
...
@@ -638,7 +638,7 @@ def get_scalar_constant_value(
op
=
owner
.
op
op
=
owner
.
op
idx_list
=
op
.
idx_list
idx_list
=
op
.
idx_list
idx
=
idx_list
[
0
]
idx
=
idx_list
[
0
]
if
isinstance
(
idx
,
gof
.
Type
):
if
isinstance
(
idx
,
gof
.
C
Type
):
idx
=
get_scalar_constant_value
(
idx
=
get_scalar_constant_value
(
owner
.
inputs
[
1
],
max_recur
=
max_recur
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
)
...
...
theano/tensor/random/type.py
浏览文件 @
f35aa65b
...
@@ -3,10 +3,10 @@ import sys
...
@@ -3,10 +3,10 @@ import sys
import
numpy
as
np
import
numpy
as
np
import
theano
import
theano
from
theano.gof.type
import
Type
from
theano.gof.type
import
C
Type
class
RandomStateType
(
Type
):
class
RandomStateType
(
C
Type
):
"""A Type wrapper for `numpy.random.RandomState`.
"""A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that
The reason this exists (and `Generic` doesn't suffice) is that
...
...
theano/tensor/subtensor.py
浏览文件 @
f35aa65b
...
@@ -7,13 +7,13 @@ from textwrap import dedent
...
@@ -7,13 +7,13 @@ from textwrap import dedent
import
numpy
as
np
import
numpy
as
np
import
theano
import
theano
from
theano
import
gof
from
theano
import
scalar
as
scal
from
theano
import
scalar
as
scal
from
theano.configdefaults
import
config
from
theano.configdefaults
import
config
from
theano.gof
import
MethodNotDefined
,
ParamsType
from
theano.gof.graph
import
Apply
,
Variable
from
theano.gof.graph
import
Apply
from
theano.gof.op
import
COp
,
Op
from
theano.gof.op
import
COp
,
Op
from
theano.gof.type
import
Type
from
theano.gof.params_type
import
ParamsType
from
theano.gof.type
import
CType
from
theano.gof.utils
import
MethodNotDefined
from
theano.gradient
import
DisconnectedType
from
theano.gradient
import
DisconnectedType
from
theano.misc.safe_asarray
import
_asarray
from
theano.misc.safe_asarray
import
_asarray
from
theano.printing
import
pprint
from
theano.printing
import
pprint
...
@@ -56,7 +56,7 @@ def as_index_constant(a):
...
@@ -56,7 +56,7 @@ def as_index_constant(a):
)
)
elif
isinstance
(
a
,
(
int
,
np
.
integer
)):
elif
isinstance
(
a
,
(
int
,
np
.
integer
)):
return
scal
.
ScalarConstant
(
scal
.
int64
,
a
)
return
scal
.
ScalarConstant
(
scal
.
int64
,
a
)
elif
not
isinstance
(
a
,
theano
.
tensor
.
Variable
):
elif
not
isinstance
(
a
,
Variable
):
return
theano
.
tensor
.
as_tensor
(
a
)
return
theano
.
tensor
.
as_tensor
(
a
)
else
:
else
:
return
a
return
a
...
@@ -82,7 +82,7 @@ def get_idx_list(inputs, idx_list, get_count=False):
...
@@ -82,7 +82,7 @@ def get_idx_list(inputs, idx_list, get_count=False):
# General case
# General case
def
convert
(
entry
):
def
convert
(
entry
):
if
isinstance
(
entry
,
gof
.
Type
):
if
isinstance
(
entry
,
C
Type
):
return
indices
.
pop
()
return
indices
.
pop
()
elif
isinstance
(
entry
,
slice
):
elif
isinstance
(
entry
,
slice
):
return
slice
(
convert
(
entry
.
start
),
convert
(
entry
.
stop
),
convert
(
entry
.
step
))
return
slice
(
convert
(
entry
.
start
),
convert
(
entry
.
stop
),
convert
(
entry
.
step
))
...
@@ -115,7 +115,7 @@ def get_canonical_form_slice(theslice, length):
...
@@ -115,7 +115,7 @@ def get_canonical_form_slice(theslice, length):
try
:
try
:
x_constant
=
get_scalar_constant_value
(
x
)
x_constant
=
get_scalar_constant_value
(
x
)
is_constant
=
True
is_constant
=
True
except
theano
.
tensor
.
NotScalarConstantError
:
except
NotScalarConstantError
:
x_constant
=
theano
.
tensor
.
extract_constant
(
x
)
x_constant
=
theano
.
tensor
.
extract_constant
(
x
)
is_constant
=
False
is_constant
=
False
return
x_constant
,
is_constant
return
x_constant
,
is_constant
...
@@ -487,30 +487,30 @@ class Subtensor(COp):
...
@@ -487,30 +487,30 @@ class Subtensor(COp):
)
)
if
(
if
(
isinstance
(
entry
,
(
np
.
ndarray
,
theano
.
tensor
.
Variable
))
isinstance
(
entry
,
(
np
.
ndarray
,
Variable
))
and
hasattr
(
entry
,
"dtype"
)
and
hasattr
(
entry
,
"dtype"
)
and
entry
.
dtype
==
"bool"
and
entry
.
dtype
==
"bool"
):
):
raise
AdvancedIndexingError
(
"Invalid index type or slice for Subtensor"
)
raise
AdvancedIndexingError
(
"Invalid index type or slice for Subtensor"
)
if
isinstance
(
entry
,
gof
.
Variable
)
and
(
if
isinstance
(
entry
,
Variable
)
and
(
entry
.
type
in
invalid_scal_types
or
entry
.
type
in
invalid_tensor_types
entry
.
type
in
invalid_scal_types
or
entry
.
type
in
invalid_tensor_types
):
):
raise
TypeError
(
"Expected an integer"
)
raise
TypeError
(
"Expected an integer"
)
if
isinstance
(
entry
,
gof
.
Variable
)
and
entry
.
type
in
scal_types
:
if
isinstance
(
entry
,
Variable
)
and
entry
.
type
in
scal_types
:
return
entry
.
type
return
entry
.
type
elif
isinstance
(
entry
,
gof
.
Type
)
and
entry
in
scal_types
:
elif
isinstance
(
entry
,
C
Type
)
and
entry
in
scal_types
:
return
entry
return
entry
if
(
if
(
isinstance
(
entry
,
gof
.
Variable
)
isinstance
(
entry
,
Variable
)
and
entry
.
type
in
tensor_types
and
entry
.
type
in
tensor_types
and
np
.
all
(
entry
.
type
.
broadcastable
)
and
np
.
all
(
entry
.
type
.
broadcastable
)
):
):
return
scal
.
get_scalar_type
(
entry
.
type
.
dtype
)
return
scal
.
get_scalar_type
(
entry
.
type
.
dtype
)
elif
(
elif
(
isinstance
(
entry
,
gof
.
Type
)
isinstance
(
entry
,
C
Type
)
and
entry
in
tensor_types
and
entry
in
tensor_types
and
np
.
all
(
entry
.
broadcastable
)
and
np
.
all
(
entry
.
broadcastable
)
):
):
...
@@ -553,7 +553,7 @@ class Subtensor(COp):
...
@@ -553,7 +553,7 @@ class Subtensor(COp):
"""
"""
Return the idx_list with constant inputs replaced by their
Return the idx_list with constant inputs replaced by their
python scalar equivalent.
python scalar equivalent.
May raise `
theano.tensor.
NotScalarConstantError` if the idx contains
May raise `NotScalarConstantError` if the idx contains
non-constant entries.
non-constant entries.
If allow_partial is True, then entries that are not constant will
If allow_partial is True, then entries that are not constant will
...
@@ -594,7 +594,7 @@ class Subtensor(COp):
...
@@ -594,7 +594,7 @@ class Subtensor(COp):
only_process_constants
=
only_process_constants
,
only_process_constants
=
only_process_constants
,
elemwise
=
elemwise
,
elemwise
=
elemwise
,
)
)
except
theano
.
tensor
.
NotScalarConstantError
:
except
NotScalarConstantError
:
if
allow_partial
:
if
allow_partial
:
return
val
return
val
else
:
else
:
...
@@ -610,7 +610,7 @@ class Subtensor(COp):
...
@@ -610,7 +610,7 @@ class Subtensor(COp):
# Since scal.as_scalar does not know about tensor types (it would
# Since scal.as_scalar does not know about tensor types (it would
# create a circular import) , this method converts either a
# create a circular import) , this method converts either a
# TensorVariable or a ScalarVariable to a scalar.
# TensorVariable or a ScalarVariable to a scalar.
if
isinstance
(
a
,
gof
.
Variable
)
and
isinstance
(
a
.
type
,
TensorType
):
if
isinstance
(
a
,
Variable
)
and
isinstance
(
a
.
type
,
TensorType
):
return
theano
.
tensor
.
scalar_from_tensor
(
a
)
return
theano
.
tensor
.
scalar_from_tensor
(
a
)
else
:
else
:
return
scal
.
as_scalar
(
a
)
return
scal
.
as_scalar
(
a
)
...
@@ -633,7 +633,7 @@ class Subtensor(COp):
...
@@ -633,7 +633,7 @@ class Subtensor(COp):
raise
IndexError
(
"too many indices for array"
)
raise
IndexError
(
"too many indices for array"
)
input_types
=
Subtensor
.
collapse
(
input_types
=
Subtensor
.
collapse
(
idx_list
,
lambda
entry
:
isinstance
(
entry
,
gof
.
Type
)
idx_list
,
lambda
entry
:
isinstance
(
entry
,
C
Type
)
)
)
if
len
(
inputs
)
!=
len
(
input_types
):
if
len
(
inputs
)
!=
len
(
input_types
):
raise
IndexError
(
raise
IndexError
(
...
@@ -672,7 +672,7 @@ class Subtensor(COp):
...
@@ -672,7 +672,7 @@ class Subtensor(COp):
broadcastable
.
append
(
False
)
broadcastable
.
append
(
False
)
return
gof
.
Apply
(
return
Apply
(
self
,
self
,
(
x
,)
+
inputs
,
(
x
,)
+
inputs
,
[
theano
.
tensor
.
tensor
(
dtype
=
x
.
type
.
dtype
,
broadcastable
=
broadcastable
)],
[
theano
.
tensor
.
tensor
(
dtype
=
x
.
type
.
dtype
,
broadcastable
=
broadcastable
)],
...
@@ -851,7 +851,7 @@ class Subtensor(COp):
...
@@ -851,7 +851,7 @@ class Subtensor(COp):
inc_spec_pos
(
1
)
inc_spec_pos
(
1
)
if
depth
==
0
:
if
depth
==
0
:
is_slice
.
append
(
0
)
is_slice
.
append
(
0
)
elif
isinstance
(
entry
,
Type
):
elif
isinstance
(
entry
,
C
Type
):
init_cmds
.
append
(
init_cmds
.
append
(
"subtensor_spec[
%
i] =
%
s;"
%
(
spec_pos
(),
inputs
[
input_pos
()])
"subtensor_spec[
%
i] =
%
s;"
%
(
spec_pos
(),
inputs
[
input_pos
()])
)
)
...
@@ -1050,7 +1050,7 @@ class Subtensor(COp):
...
@@ -1050,7 +1050,7 @@ class Subtensor(COp):
return
(
9
,)
return
(
9
,)
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
# DEBUG
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
# DEBUG
if
not
isinstance
(
node
.
inputs
[
0
]
.
type
,
theano
.
tensor
.
TensorType
):
if
not
isinstance
(
node
.
inputs
[
0
]
.
type
,
TensorType
):
raise
NotImplementedError
()
raise
NotImplementedError
()
x
=
inputs
[
0
]
x
=
inputs
[
0
]
...
@@ -1469,7 +1469,7 @@ class IncSubtensor(COp):
...
@@ -1469,7 +1469,7 @@ class IncSubtensor(COp):
raise
IndexError
(
"too many indices for array"
)
raise
IndexError
(
"too many indices for array"
)
input_types
=
Subtensor
.
collapse
(
input_types
=
Subtensor
.
collapse
(
idx_list
,
lambda
entry
:
isinstance
(
entry
,
gof
.
Type
)
idx_list
,
lambda
entry
:
isinstance
(
entry
,
C
Type
)
)
)
if
len
(
inputs
)
!=
len
(
input_types
):
if
len
(
inputs
)
!=
len
(
input_types
):
raise
IndexError
(
raise
IndexError
(
...
@@ -1482,7 +1482,7 @@ class IncSubtensor(COp):
...
@@ -1482,7 +1482,7 @@ class IncSubtensor(COp):
%
(
input
.
type
,
expected_type
)
%
(
input
.
type
,
expected_type
)
)
)
return
gof
.
Apply
(
self
,
(
x
,
y
)
+
inputs
,
[
x
.
type
()])
return
Apply
(
self
,
(
x
,
y
)
+
inputs
,
[
x
.
type
()])
def
decl_view
(
self
):
def
decl_view
(
self
):
return
"PyArrayObject * zview = NULL;"
return
"PyArrayObject * zview = NULL;"
...
@@ -1493,7 +1493,7 @@ class IncSubtensor(COp):
...
@@ -1493,7 +1493,7 @@ class IncSubtensor(COp):
indices
=
list
(
reversed
(
inputs
[
2
:]))
indices
=
list
(
reversed
(
inputs
[
2
:]))
def
convert
(
entry
):
def
convert
(
entry
):
if
isinstance
(
entry
,
gof
.
Type
):
if
isinstance
(
entry
,
C
Type
):
return
indices
.
pop
()
return
indices
.
pop
()
elif
isinstance
(
entry
,
slice
):
elif
isinstance
(
entry
,
slice
):
return
slice
(
return
slice
(
...
@@ -1645,7 +1645,7 @@ class IncSubtensor(COp):
...
@@ -1645,7 +1645,7 @@ class IncSubtensor(COp):
"""
"""
if
not
isinstance
(
node
.
inputs
[
0
]
.
type
,
theano
.
tensor
.
TensorType
):
if
not
isinstance
(
node
.
inputs
[
0
]
.
type
,
TensorType
):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
...
@@ -2239,9 +2239,9 @@ def as_index_variable(idx):
...
@@ -2239,9 +2239,9 @@ def as_index_variable(idx):
return
NoneConst
.
clone
()
return
NoneConst
.
clone
()
if
isinstance
(
idx
,
slice
):
if
isinstance
(
idx
,
slice
):
return
make_slice
(
idx
)
return
make_slice
(
idx
)
if
isinstance
(
idx
,
gof
.
Variable
)
and
isinstance
(
idx
.
type
,
SliceType
):
if
isinstance
(
idx
,
Variable
)
and
isinstance
(
idx
.
type
,
SliceType
):
return
idx
return
idx
if
isinstance
(
idx
,
gof
.
Variable
)
and
isinstance
(
idx
.
type
,
NoneTypeT
):
if
isinstance
(
idx
,
Variable
)
and
isinstance
(
idx
.
type
,
NoneTypeT
):
return
idx
return
idx
idx
=
theano
.
tensor
.
as_tensor_variable
(
idx
)
idx
=
theano
.
tensor
.
as_tensor_variable
(
idx
)
if
idx
.
type
.
dtype
not
in
theano
.
tensor
.
discrete_dtypes
:
if
idx
.
type
.
dtype
not
in
theano
.
tensor
.
discrete_dtypes
:
...
@@ -2312,7 +2312,7 @@ class AdvancedSubtensor(Op):
...
@@ -2312,7 +2312,7 @@ class AdvancedSubtensor(Op):
for
i
in
indexed_result_shape
(
fake_shape
,
bcast_index
)
for
i
in
indexed_result_shape
(
fake_shape
,
bcast_index
)
]
]
return
gof
.
Apply
(
return
Apply
(
self
,
self
,
(
x
,)
+
index
,
(
x
,)
+
index
,
[
theano
.
tensor
.
tensor
(
dtype
=
x
.
type
.
dtype
,
broadcastable
=
bcast
)],
[
theano
.
tensor
.
tensor
(
dtype
=
x
.
type
.
dtype
,
broadcastable
=
bcast
)],
...
@@ -2415,7 +2415,7 @@ class AdvancedIncSubtensor(Op):
...
@@ -2415,7 +2415,7 @@ class AdvancedIncSubtensor(Op):
if
isinstance
(
inp
,
(
list
,
tuple
)):
if
isinstance
(
inp
,
(
list
,
tuple
)):
inp
=
theano
.
tensor
.
as_tensor_variable
(
inp
)
inp
=
theano
.
tensor
.
as_tensor_variable
(
inp
)
new_inputs
.
append
(
inp
)
new_inputs
.
append
(
inp
)
return
gof
.
Apply
(
return
Apply
(
self
,
self
,
(
x
,
y
)
+
tuple
(
new_inputs
),
(
x
,
y
)
+
tuple
(
new_inputs
),
[
[
...
...
theano/tensor/type.py
浏览文件 @
f35aa65b
...
@@ -7,14 +7,14 @@ import theano
...
@@ -7,14 +7,14 @@ import theano
from
theano
import
scalar
as
scal
from
theano
import
scalar
as
scal
from
theano.configdefaults
import
config
from
theano.configdefaults
import
config
from
theano.gof.graph
import
Variable
from
theano.gof.graph
import
Variable
from
theano.gof.type
import
Type
from
theano.gof.type
import
C
Type
from
theano.misc.safe_asarray
import
_asarray
from
theano.misc.safe_asarray
import
_asarray
_logger
=
logging
.
getLogger
(
"theano.tensor.type"
)
_logger
=
logging
.
getLogger
(
"theano.tensor.type"
)
class
TensorType
(
Type
):
class
TensorType
(
C
Type
):
"""
"""
Symbolic `Type` representing a numpy.ndarray value.
Symbolic `Type` representing a numpy.ndarray value.
...
...
theano/tensor/type_other.py
浏览文件 @
f35aa65b
...
@@ -7,7 +7,7 @@ import numpy as np
...
@@ -7,7 +7,7 @@ import numpy as np
import
theano
import
theano
from
theano.gof.graph
import
Apply
,
Constant
from
theano.gof.graph
import
Apply
,
Constant
from
theano.gof.op
import
Op
from
theano.gof.op
import
Op
from
theano.gof.type
import
Generic
,
Type
from
theano.gof.type
import
CType
,
Generic
from
theano.gradient
import
DisconnectedType
from
theano.gradient
import
DisconnectedType
...
@@ -49,7 +49,7 @@ class MakeSlice(Op):
...
@@ -49,7 +49,7 @@ class MakeSlice(Op):
make_slice
=
MakeSlice
()
make_slice
=
MakeSlice
()
class
SliceType
(
Type
):
class
SliceType
(
C
Type
):
def
filter
(
self
,
x
,
strict
=
False
,
allow_downcast
=
None
):
def
filter
(
self
,
x
,
strict
=
False
,
allow_downcast
=
None
):
if
isinstance
(
x
,
slice
):
if
isinstance
(
x
,
slice
):
return
x
return
x
...
...
theano/typed_list/type.py
浏览文件 @
f35aa65b
from
theano.gof.type
import
Type
from
theano.gof.type
import
CType
,
Type
class
TypedListType
(
Type
):
class
TypedListType
(
C
Type
):
"""
"""
Parameters
Parameters
...
...
theano/utils.py
浏览文件 @
f35aa65b
...
@@ -119,39 +119,6 @@ def get_unbound_function(unbound):
...
@@ -119,39 +119,6 @@ def get_unbound_function(unbound):
return
unbound
return
unbound
class
DefaultOrderedDict
(
OrderedDict
):
def
__init__
(
self
,
default_factory
=
None
,
*
a
,
**
kw
):
if
default_factory
is
not
None
and
not
isinstance
(
default_factory
,
Callable
):
raise
TypeError
(
"first argument must be callable"
)
OrderedDict
.
__init__
(
self
,
*
a
,
**
kw
)
self
.
default_factory
=
default_factory
def
__getitem__
(
self
,
key
):
try
:
return
OrderedDict
.
__getitem__
(
self
,
key
)
except
KeyError
:
return
self
.
__missing__
(
key
)
def
__missing__
(
self
,
key
):
if
self
.
default_factory
is
None
:
raise
KeyError
(
key
)
self
[
key
]
=
value
=
self
.
default_factory
()
return
value
def
__reduce__
(
self
):
if
self
.
default_factory
is
None
:
args
=
tuple
()
else
:
args
=
(
self
.
default_factory
,)
return
type
(
self
),
args
,
None
,
None
,
list
(
self
.
items
())
def
copy
(
self
):
return
self
.
__copy__
()
def
__copy__
(
self
):
return
type
(
self
)(
self
.
default_factory
,
self
)
def
maybe_add_to_os_environ_pathlist
(
var
,
newpath
):
def
maybe_add_to_os_environ_pathlist
(
var
,
newpath
):
"""Unfortunately, Conda offers to make itself the default Python
"""Unfortunately, Conda offers to make itself the default Python
and those who use it that way will probably not activate envs
and those who use it that way will probably not activate envs
...
@@ -377,22 +344,6 @@ def flatten(a):
...
@@ -377,22 +344,6 @@ def flatten(a):
return
[
a
]
return
[
a
]
class
NoDuplicateOptWarningFilter
(
logging
.
Filter
):
"""Filter to avoid duplicating optimization warnings."""
prev_msgs
=
set
()
def
filter
(
self
,
record
):
msg
=
record
.
getMessage
()
if
msg
.
startswith
(
"Optimization Warning: "
):
if
msg
in
self
.
prev_msgs
:
return
False
else
:
self
.
prev_msgs
.
add
(
msg
)
return
True
return
True
def
apply_across_args
(
*
fns
):
def
apply_across_args
(
*
fns
):
"""Create new functions that distributes the wrapped functions across iterable arguments.
"""Create new functions that distributes the wrapped functions across iterable arguments.
...
@@ -418,3 +369,85 @@ def apply_across_args(*fns):
...
@@ -418,3 +369,85 @@ def apply_across_args(*fns):
return
partial
(
f2
,
fns
[
0
])
return
partial
(
f2
,
fns
[
0
])
else
:
else
:
return
[
partial
(
f2
,
f
)
for
f
in
fns
]
return
[
partial
(
f2
,
f
)
for
f
in
fns
]
class
NoDuplicateOptWarningFilter
(
logging
.
Filter
):
"""Filter to avoid duplicating optimization warnings."""
prev_msgs
=
set
()
def
filter
(
self
,
record
):
msg
=
record
.
getMessage
()
if
msg
.
startswith
(
"Optimization Warning: "
):
if
msg
in
self
.
prev_msgs
:
return
False
else
:
self
.
prev_msgs
.
add
(
msg
)
return
True
return
True
class
Singleton
:
"""Convenient base class for a singleton.
It saves having to implement __eq__ and __hash__.
"""
__instance
=
None
def
__new__
(
cls
):
# If sub-subclass of SingletonType don't redeclare __instance
# when we look for it, we will find it in the subclass. We
# don't want that, so we check the class. When we add one, we
# add one only to the current class, so all is working
# correctly.
if
cls
.
__instance
is
None
or
not
isinstance
(
cls
.
__instance
,
cls
):
cls
.
__instance
=
super
()
.
__new__
(
cls
)
return
cls
.
__instance
def
__str__
(
self
):
return
self
.
__class__
.
__name__
def
__eq__
(
self
,
other
):
if
self
is
other
:
return
True
if
type
(
self
)
is
type
(
other
):
return
True
return
False
def
__hash__
(
self
):
return
hash
(
type
(
self
))
class
DefaultOrderedDict
(
OrderedDict
):
def
__init__
(
self
,
default_factory
=
None
,
*
a
,
**
kw
):
if
default_factory
is
not
None
and
not
isinstance
(
default_factory
,
Callable
):
raise
TypeError
(
"first argument must be callable"
)
OrderedDict
.
__init__
(
self
,
*
a
,
**
kw
)
self
.
default_factory
=
default_factory
def
__getitem__
(
self
,
key
):
try
:
return
OrderedDict
.
__getitem__
(
self
,
key
)
except
KeyError
:
return
self
.
__missing__
(
key
)
def
__missing__
(
self
,
key
):
if
self
.
default_factory
is
None
:
raise
KeyError
(
key
)
self
[
key
]
=
value
=
self
.
default_factory
()
return
value
def
__reduce__
(
self
):
if
self
.
default_factory
is
None
:
args
=
tuple
()
else
:
args
=
(
self
.
default_factory
,)
return
type
(
self
),
args
,
None
,
None
,
list
(
self
.
items
())
def
copy
(
self
):
return
self
.
__copy__
()
def
__copy__
(
self
):
return
type
(
self
)(
self
.
default_factory
,
self
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论