提交 f35aa65b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename PureType to Type and Type to CType

上级 6c6d81c6
...@@ -25,7 +25,7 @@ input variables and place the variables in the output variables. ...@@ -25,7 +25,7 @@ input variables and place the variables in the output variables.
What needs to be defined What needs to be defined
======================== ========================
There are less methods to define for an `COp` than for a Type: There are less methods to define for a `COp` than for a `Type`:
.. class:: COp .. class:: COp
...@@ -213,9 +213,9 @@ There are less methods to define for an `COp` than for a Type: ...@@ -213,9 +213,9 @@ There are less methods to define for an `COp` than for a Type:
Optional. If present this method will be called before doing Optional. If present this method will be called before doing
constant folding of a node, with that node as a parameter. If constant folding of a node, with that node as a parameter. If
it return True, we will not generate c code when doing constant it return True, we will not generate C code when doing constant
folding of this node. This is useful when the compilation of folding of this node. This is useful when the compilation of
the c code will be longer then the computation in python the C code will be longer then the computation in python
(e.g. Elemwise of scalars). (e.g. Elemwise of scalars).
In addition, this allow to lower the number of compiled module In addition, this allow to lower the number of compiled module
...@@ -233,7 +233,7 @@ There are less methods to define for an `COp` than for a Type: ...@@ -233,7 +233,7 @@ There are less methods to define for an `COp` than for a Type:
considered the same as if the method was not defined. considered the same as if the method was not defined.
If this method is defined and does not return `None`, then the If this method is defined and does not return `None`, then the
Op *must* have a `params_type` property with the Type to use `Op` *must* have a `params_type` property with the `Type` to use
for the params variable. for the params variable.
.. attribute:: _f16_ok .. attribute:: _f16_ok
...@@ -252,7 +252,7 @@ There are less methods to define for an `COp` than for a Type: ...@@ -252,7 +252,7 @@ There are less methods to define for an `COp` than for a Type:
developpment if a better solution is found. developpment if a better solution is found.
The ``name`` argument is currently given an invalid value, so steer The ``name`` argument is currently given an invalid value, so steer
away from it. As was the case with Type, ``sub['fail']`` provides away from it. As was the case with `Type`, ``sub['fail']`` provides
failure code that you *must* use if you want to raise an exception, failure code that you *must* use if you want to raise an exception,
after setting the exception message. after setting the exception message.
......
...@@ -6,7 +6,7 @@ Implementing double in C ...@@ -6,7 +6,7 @@ Implementing double in C
======================== ========================
The previous two sections described how to define a double :ref:`type` The previous two sections described how to define a double :ref:`type`
and arithmetic operations on that Type, but all of them were and arithmetic operations on that `Type`, but all of them were
implemented in pure Python. In this section we will see how to define implemented in pure Python. In this section we will see how to define
the double type in such a way that it can be used by operations the double type in such a way that it can be used by operations
implemented in C (which we will define in the section after that). implemented in C (which we will define in the section after that).
...@@ -15,15 +15,15 @@ implemented in C (which we will define in the section after that). ...@@ -15,15 +15,15 @@ implemented in C (which we will define in the section after that).
How does it work? How does it work?
================= =================
In order to be C-compatible, a Type must provide a C interface to the In order to be C-compatible, a `Type` must provide a C interface to the
Python data that satisfy the constraints it puts forward. In other Python data that satisfy the constraints it puts forward. In other
words, it must define C code that can convert a Python reference into words, it must define C code that can convert a Python reference into
some type suitable for manipulation in C and it must define C code some type suitable for manipulation in C and it must define C code
that can convert some C structure in which the C implementation of an that can convert some C structure in which the C implementation of an
operation stores its variables into a reference to an object that can be operation stores its variables into a reference to an object that can be
used from Python and is a valid value for the Type. used from Python and is a valid value for the `Type`.
For example, in the current example, we have a Type which represents a For example, in the current example, we have a `Type` which represents a
Python float. First, we will choose a corresponding C type. The Python float. First, we will choose a corresponding C type. The
natural choice would be the primitive ``double`` type. Then, we need natural choice would be the primitive ``double`` type. Then, we need
to write code that will take a ``PyObject*``, check that it is a to write code that will take a ``PyObject*``, check that it is a
...@@ -42,10 +42,10 @@ find here_. ...@@ -42,10 +42,10 @@ find here_.
What needs to be defined What needs to be defined
======================== ========================
In order to be C-compatible, a Type must define several additional In order to be C-compatible, the `Type` subclass interface `CType` must be used.
methods, which all start with the ``c_`` prefix. The complete list can It defines several additional methods, which all start with the ``c_``
be found in the documentation for :class:`.gof.type.Type`. Here, we'll focus on prefix. The complete list can be found in the documentation for
the most important ones: :class:`.gof.type.CType`. Here, we'll focus on the most important ones:
.. class:: CLinkerType .. class:: CLinkerType
...@@ -144,7 +144,7 @@ the most important ones: ...@@ -144,7 +144,7 @@ the most important ones:
Each of these functions take two arguments, ``name`` and ``sub`` which Each of these functions take two arguments, ``name`` and ``sub`` which
must be used to parameterize the C code they return. ``name`` is a must be used to parameterize the C code they return. ``name`` is a
string which is chosen by the compiler to represent a :ref:`variable` of string which is chosen by the compiler to represent a :ref:`variable` of
the Type in such a way that there are no name conflicts between the `CType` in such a way that there are no name conflicts between
different pieces of data. Therefore, all variables declared in different pieces of data. Therefore, all variables declared in
``c_declare`` should have a name which includes ``name``. Furthermore, ``c_declare`` should have a name which includes ``name``. Furthermore,
the name of the variable containing a pointer to the Python object the name of the variable containing a pointer to the Python object
...@@ -180,20 +180,19 @@ out: ...@@ -180,20 +180,19 @@ out:
Defining the methods Defining the methods
==================== ====================
.. testsetup::
import theano
double = theano.Type()
**c_declare** **c_declare**
.. testcode:: .. testcode::
def c_declare(name, sub): from theano.gof.type import Generic
class double(Generic):
def c_declare(self, name, sub, check_input=True):
return """ return """
double %(name)s; double %(name)s;
""" % dict(name = name) """ % dict(name = name)
double.c_declare = c_declare
Very straightforward. All we need to do is write C code to declare a Very straightforward. All we need to do is write C code to declare a
double. That double will be named whatever is passed to our function double. That double will be named whatever is passed to our function
...@@ -211,7 +210,7 @@ here). Also note that you cannot declare a variable called ...@@ -211,7 +210,7 @@ here). Also note that you cannot declare a variable called
them. them.
What you declare there is basically the C interface you are giving to What you declare there is basically the C interface you are giving to
your Type. If you wish people to develop operations that make use of your `CType`. If you wish people to develop operations that make use of
it, it's best to publish it somewhere. it, it's best to publish it somewhere.
...@@ -219,11 +218,10 @@ it, it's best to publish it somewhere. ...@@ -219,11 +218,10 @@ it, it's best to publish it somewhere.
.. testcode:: .. testcode::
def c_init(name, sub): def c_init(self, name, sub):
return """ return """
%(name)s = 0.0; %(name)s = 0.0;
""" % dict(name = name) """ % dict(name = name)
double.c_init = c_init
This function has to initialize the This function has to initialize the
double we declared previously to a suitable value. This is useful if double we declared previously to a suitable value. This is useful if
...@@ -245,7 +243,7 @@ called, without knowing for sure which of the two. ...@@ -245,7 +243,7 @@ called, without knowing for sure which of the two.
.. testcode:: .. testcode::
def c_extract(name, sub): def c_extract(self, name, sub, check_input=True):
return """ return """
if (!PyFloat_Check(py_%(name)s)) { if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float"); PyErr_SetString(PyExc_TypeError, "expected a float");
...@@ -253,7 +251,6 @@ called, without knowing for sure which of the two. ...@@ -253,7 +251,6 @@ called, without knowing for sure which of the two.
} }
%(name)s = PyFloat_AsDouble(py_%(name)s); %(name)s = PyFloat_AsDouble(py_%(name)s);
""" % dict(name = name, fail = sub['fail']) """ % dict(name = name, fail = sub['fail'])
double.c_extract = c_extract
This method is slightly more sophisticated. What happens here is that This method is slightly more sophisticated. What happens here is that
we have a reference to a Python object which Theano has placed in we have a reference to a Python object which Theano has placed in
...@@ -469,9 +466,9 @@ Final version ...@@ -469,9 +466,9 @@ Final version
.. testcode:: .. testcode::
from theano import gof from theano.gof.type import
class Double(gof.Type): class Double(Type):
def filter(self, x, strict=False, allow_downcast=None): def filter(self, x, strict=False, allow_downcast=None):
if strict and not isinstance(x, float): if strict and not isinstance(x, float):
......
...@@ -17,7 +17,7 @@ write a new one. Don't worry, Theano was designed to make it easy to add new ...@@ -17,7 +17,7 @@ write a new one. Don't worry, Theano was designed to make it easy to add new
Ops, Types, and Optimizations. Ops, Types, and Optimizations.
.. These first few pages will walk you through the definition of a new :ref:`type`, .. These first few pages will walk you through the definition of a new :ref:`type`,
.. ``double``, and a basic arithmetic :ref:`operations <op>` on that Type. .. ``double``, and a basic arithmetic :ref:`operations <op>` on that `Type`.
As an illustration, this tutorial shows how to write a simple Python-based As an illustration, this tutorial shows how to write a simple Python-based
:ref:`operations <op>` which performs operations on :ref:`operations <op>` which performs operations on
...@@ -134,7 +134,7 @@ or :func:`make_thunk`. ...@@ -134,7 +134,7 @@ or :func:`make_thunk`.
- it operates on the Variables found in - it operates on the Variables found in
``*inputs`` in Theano's symbolic language to infer the type of ``*inputs`` in Theano's symbolic language to infer the type of
the symbolic output Variables. It creates output Variables of a suitable the symbolic output Variables. It creates output Variables of a suitable
symbolic Type to serve as the outputs of this op's symbolic `Type` to serve as the outputs of this op's
application. application.
- it creates an Apply instance with the input and output Variable, and - it creates an Apply instance with the input and output Variable, and
return the Apply instance. return the Apply instance.
...@@ -397,7 +397,7 @@ A common and easy way to ensure inputs are variables is to run them through ...@@ -397,7 +397,7 @@ A common and easy way to ensure inputs are variables is to run them through
``as_tensor_variable``. This function leaves TensorType variables alone, raises ``as_tensor_variable``. This function leaves TensorType variables alone, raises
an error for non-TensorType variables, and copies any ``numpy.ndarray`` into an error for non-TensorType variables, and copies any ``numpy.ndarray`` into
the storage for a TensorType Constant. The ``make_node`` method dictates the the storage for a TensorType Constant. The ``make_node`` method dictates the
appropriate Type for all output variables. appropriate `Type` for all output variables.
The ``perform`` method implements the Op's mathematical logic in Python. The ``perform`` method implements the Op's mathematical logic in Python.
The inputs (here ``x``) are passed by value, but a single output is returned The inputs (here ``x``) are passed by value, but a single output is returned
......
...@@ -50,7 +50,7 @@ define the following methods. ...@@ -50,7 +50,7 @@ define the following methods.
.. function:: make_node(*inputs) .. function:: make_node(*inputs)
This method is responsible for creating output Variables of a This method is responsible for creating output Variables of a
suitable symbolic Type to serve as the outputs of this Op's suitable symbolic `Type` to serve as the outputs of this Op's
application. The Variables found in ``*inputs`` must be operated on application. The Variables found in ``*inputs`` must be operated on
using Theano's symbolic language to compute the symbolic output using Theano's symbolic language to compute the symbolic output
Variables. This method should put these outputs into an Apply Variables. This method should put these outputs into an Apply
...@@ -769,7 +769,7 @@ as first argument to Apply. We define ``perform`` using the function ...@@ -769,7 +769,7 @@ as first argument to Apply. We define ``perform`` using the function
``fn`` passed in the constructor. ``fn`` passed in the constructor.
This design is a flexible way to define basic operations without This design is a flexible way to define basic operations without
duplicating code. The same way a Type subclass represents a set of duplicating code. The same way a `Type` subclass represents a set of
structurally similar types (see previous section), an `Op` subclass structurally similar types (see previous section), an `Op` subclass
represents a set of structurally similar operations: operations that represents a set of structurally similar operations: operations that
have the same input/output types, operations that only differ in one have the same input/output types, operations that only differ in one
......
...@@ -266,8 +266,8 @@ along with pointers to the relevant documentation. ...@@ -266,8 +266,8 @@ along with pointers to the relevant documentation.
primitive type. The C type associated with this Theano type is the primitive type. The C type associated with this Theano type is the
represented C primitive itself. represented C primitive itself.
* :ref:`SparseType <sparse_ops>` : Theano type used to represent sparse * :ref:`SparseType <sparse_ops>` : Theano `Type` used to represent sparse
tensors. There is no equivalent C type for this Theano Type but you tensors. There is no equivalent C type for this Theano `Type` but you
can split a sparse variable into its parts as TensorVariables. Those can split a sparse variable into its parts as TensorVariables. Those
can then be used as inputs to an op with C code. can then be used as inputs to an op with C code.
......
...@@ -10,7 +10,7 @@ Making the double type ...@@ -10,7 +10,7 @@ Making the double type
Type's contract Type's contract
=============== ===============
In Theano's framework, a ``Type`` (:class:`.gof.type.Type`) In Theano's framework, a ``Type`` (:class:`Type`)
is any object which defines the following is any object which defines the following
methods. To obtain the default methods described below, the Type should methods. To obtain the default methods described below, the Type should
be an instance of ``Type`` or should be an instance of a be an instance of ``Type`` or should be an instance of a
...@@ -22,7 +22,7 @@ i.e. the same default argument names and values. If you wish to add ...@@ -22,7 +22,7 @@ i.e. the same default argument names and values. If you wish to add
extra arguments to any of these methods, these extra arguments must have extra arguments to any of these methods, these extra arguments must have
default values. default values.
.. class:: PureType .. class:: Type
.. method:: filter(value, strict=False, allow_downcast=None) .. method:: filter(value, strict=False, allow_downcast=None)
...@@ -265,21 +265,21 @@ the Type is to instantiate a plain Type and set the needed fields: ...@@ -265,21 +265,21 @@ the Type is to instantiate a plain Type and set the needed fields:
.. testcode:: .. testcode::
from theano import gof from theano.gof.type import Type
double = gof.Type() double = Type()
double.filter = filter double.filter = filter
double.values_eq_approx = values_eq_approx double.values_eq_approx = values_eq_approx
Another way to make this Type is to make a subclass of ``gof.Type`` Another way to make this Type is to make a subclass of ``Type``
and define ``filter`` and ``values_eq_approx`` in the subclass: and define ``filter`` and ``values_eq_approx`` in the subclass:
.. code-block:: python .. code-block:: python
from theano import gof from theano.gof.type import Type
class Double(gof.Type): class Double(Type):
def filter(self, x, strict=False, allow_downcast=None): def filter(self, x, strict=False, allow_downcast=None):
# See code above. # See code above.
...@@ -300,9 +300,9 @@ instances of ``Double`` are technically the same Type. However, different ...@@ -300,9 +300,9 @@ instances of ``Double`` are technically the same Type. However, different
.. testsetup:: .. testsetup::
from theano import gof from theano.gof.type import Type
class Double(gof.Type): class Double(Type):
def filter(self, x, strict=False, allow_downcast=None): def filter(self, x, strict=False, allow_downcast=None):
if strict: if strict:
...@@ -399,9 +399,9 @@ Final version ...@@ -399,9 +399,9 @@ Final version
.. testcode:: .. testcode::
from theano import gof from theano.gof.type import Type
class Double(gof.Type): class Double(Type):
def filter(self, x, strict=False, allow_downcast=None): def filter(self, x, strict=False, allow_downcast=None):
if strict: if strict:
...@@ -432,4 +432,3 @@ Final version ...@@ -432,4 +432,3 @@ Final version
We add one utility function, ``__str__``. That way, when we print We add one utility function, ``__str__``. That way, when we print
``double``, it will print out something intelligible. ``double``, it will print out something intelligible.
...@@ -48,10 +48,10 @@ The first thing you need to do is to define a Theano Type for your ...@@ -48,10 +48,10 @@ The first thing you need to do is to define a Theano Type for your
params object. It doesn't have to be complete type because only the params object. It doesn't have to be complete type because only the
following methods will be used for the type: following methods will be used for the type:
- :meth:`filter <PureType.filter>` - :meth:`filter <Type.filter>`
- :meth:`__eq__ <PureType.__eq__>` - :meth:`__eq__ <Type.__eq__>`
- :meth:`__hash__ <PureType.__hash__>` - :meth:`__hash__ <Type.__hash__>`
- :meth:`values_eq <PureType.values_eq>` - :meth:`values_eq <Type.values_eq>`
Additionaly if you want to use your params with C code, you need to extend `COp` Additionaly if you want to use your params with C code, you need to extend `COp`
and implement the following methods: and implement the following methods:
......
...@@ -5,7 +5,7 @@ import theano ...@@ -5,7 +5,7 @@ import theano
from theano.gof import fg from theano.gof import fg
from theano.gof.graph import Apply, Constant, Variable from theano.gof.graph import Apply, Constant, Variable
from theano.gof.op import COp from theano.gof.op import COp
from theano.gof.type import Type from theano.gof.type import CType
from theano.link.basic import PerformLinker from theano.link.basic import PerformLinker
from theano.link.c.basic import CLinker, DualLinker, OpWiseCLinker from theano.link.c.basic import CLinker, DualLinker, OpWiseCLinker
...@@ -15,7 +15,7 @@ def as_variable(x): ...@@ -15,7 +15,7 @@ def as_variable(x):
return x return x
class TDouble(Type): class TDouble(CType):
def filter(self, data, strict=False, allow_downcast=False): def filter(self, data, strict=False, allow_downcast=False):
return float(data) return float(data)
......
...@@ -97,12 +97,12 @@ from theano.compile.function.types import FunctionMaker ...@@ -97,12 +97,12 @@ from theano.compile.function.types import FunctionMaker
from theano.gof import ( from theano.gof import (
Apply, Apply,
Constant, Constant,
CType,
FunctionGraph, FunctionGraph,
Generic, Generic,
InconsistencyError, InconsistencyError,
Op, Op,
OpenMPOp, OpenMPOp,
Type,
Variable, Variable,
generic, generic,
opt, opt,
......
...@@ -55,7 +55,7 @@ def cleanup(): ...@@ -55,7 +55,7 @@ def cleanup():
elif obj.startswith("c_compiler_str="): elif obj.startswith("c_compiler_str="):
have_c_compiler = True have_c_compiler = True
elif isinstance( elif isinstance(
obj, (theano.gof.Op, theano.gof.Type) obj, (theano.gof.Op, theano.gof.CType)
) and hasattr(obj, "c_code_cache_version"): ) and hasattr(obj, "c_code_cache_version"):
v = obj.c_code_cache_version() v = obj.c_code_cache_version()
if v not in [(), None] and v not in key[0]: if v not in [(), None] and v not in key[0]:
...@@ -139,7 +139,7 @@ def print_compiledir_content(): ...@@ -139,7 +139,7 @@ def print_compiledir_content():
{ {
x x
for x in flatten(keydata.keys) for x in flatten(keydata.keys)
if isinstance(x, theano.gof.Type) if isinstance(x, theano.gof.CType)
} }
) )
compile_start = compile_end = float("nan") compile_start = compile_end = float("nan")
......
...@@ -16,6 +16,7 @@ import theano ...@@ -16,6 +16,7 @@ import theano
from theano.gof import ParamsType from theano.gof import ParamsType
from theano.gof.graph import Apply, Variable from theano.gof.graph import Apply, Variable
from theano.gof.op import COp, Op from theano.gof.op import COp, Op
from theano.gof.type import CType
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
...@@ -619,11 +620,11 @@ def as_op(itypes, otypes, infer_shape=None): ...@@ -619,11 +620,11 @@ def as_op(itypes, otypes, infer_shape=None):
""" """
if not isinstance(itypes, (list, tuple)): if not isinstance(itypes, (list, tuple)):
itypes = [itypes] itypes = [itypes]
if any(not isinstance(t, theano.Type) for t in itypes): if any(not isinstance(t, CType) for t in itypes):
raise TypeError("itypes has to be a list of Theano types") raise TypeError("itypes has to be a list of Theano types")
if not isinstance(otypes, (list, tuple)): if not isinstance(otypes, (list, tuple)):
otypes = [otypes] otypes = [otypes]
if any(not isinstance(t, theano.Type) for t in otypes): if any(not isinstance(t, CType) for t in otypes):
raise TypeError("otypes has to be a list of Theano types") raise TypeError("otypes has to be a list of Theano types")
# make sure they are lists and not tuples # make sure they are lists and not tuples
......
...@@ -42,5 +42,5 @@ from theano.gof.toolbox import ( ...@@ -42,5 +42,5 @@ from theano.gof.toolbox import (
ReplaceValidate, ReplaceValidate,
Validator, Validator,
) )
from theano.gof.type import CEnumType, EnumList, EnumType, Generic, Type, generic from theano.gof.type import CEnumType, CType, EnumList, EnumType, Generic, generic
from theano.gof.utils import MetaObject, MethodNotDefined from theano.gof.utils import MetaObject, MethodNotDefined
...@@ -223,27 +223,26 @@ class Apply(Node): ...@@ -223,27 +223,26 @@ class Apply(Node):
return cp return cp
def clone_with_new_inputs(self, inputs, strict=True): def clone_with_new_inputs(self, inputs, strict=True):
""" """Duplicate this `Apply` instance in a new graph.
Duplicate this Apply instance in a new graph.
Parameters Parameters
---------- ----------
inputs inputs : list of Variables
List of Variable instances to use as inputs. List of `Variable` instances to use as inputs.
strict : bool strict : bool
If True, the type fields of all the inputs must be equal If ``True``, the type fields of all the inputs must be equal
to the current ones (or compatible, for instance Tensor / to the current ones (or compatible, for instance `Tensor` /
GpuArray of the same dtype and broadcastable patterns, `GpuArray` of the same dtype and broadcastable patterns,
in which case they will be converted into current Type), and in which case they will be converted into current `Type`), and
returned outputs are guaranteed to have the same types as returned outputs are guaranteed to have the same types as
self.outputs. If False, then there's no guarantee that the ``self.outputs``. If ``False``, then there's no guarantee that the
clone's outputs will have the same types as self.outputs, clone's outputs will have the same types as ``self.outputs``,
and cloning may not even be possible (it depends on the Op). and cloning may not even be possible (it depends on the `Op`).
Returns Returns
------- -------
object object
An Apply instance with the same op but different outputs. An `Apply` instance with the same `Op` but different outputs.
""" """
assert isinstance(inputs, (list, tuple)) assert isinstance(inputs, (list, tuple))
...@@ -672,18 +671,18 @@ def walk( ...@@ -672,18 +671,18 @@ def walk(
Parameters Parameters
---------- ----------
nodes: deque nodes : deque
The nodes from which to start walking. The nodes from which to start walking.
expand: callable expand : callable
A callable that is applied to each node in `nodes`, the results of A callable that is applied to each node in `nodes`, the results of
which are either new nodes to visit or ``None``. which are either new nodes to visit or ``None``.
bfs: bool bfs : bool
If ``True``, breath first search is used; otherwise, depth first If ``True``, breath first search is used; otherwise, depth first
search. search.
return_children: bool return_children : bool
If ``True``, each output node will be accompanied by the output of If ``True``, each output node will be accompanied by the output of
`expand` (i.e. the corresponding child nodes). `expand` (i.e. the corresponding child nodes).
hash_fn: callable hash_fn : callable
The function used to produce hashes of the elements in `nodes`. The function used to produce hashes of the elements in `nodes`.
The default is ``id``. The default is ``id``.
...@@ -735,10 +734,10 @@ def ancestors( ...@@ -735,10 +734,10 @@ def ancestors(
Parameters Parameters
---------- ----------
graphs: list of `Variable` instances graphs : list of `Variable` instances
Output `Variable` instances from which to search backward through Output `Variable` instances from which to search backward through
owners. owners.
blockers: list of `Variable` instances blockers : list of `Variable` instances
A collection of `Variable`s that, when found, prevent the graph search A collection of `Variable`s that, when found, prevent the graph search
from preceding from that point. from preceding from that point.
...@@ -764,10 +763,10 @@ def graph_inputs( ...@@ -764,10 +763,10 @@ def graph_inputs(
Parameters Parameters
---------- ----------
graphs: list of `Variable` instances graphs : list of `Variable` instances
Output `Variable` instances from which to search backward through Output `Variable` instances from which to search backward through
owners. owners.
blockers: list of `Variable` instances blockers : list of `Variable` instances
A collection of `Variable`s that, when found, prevent the graph search A collection of `Variable`s that, when found, prevent the graph search
from preceding from that point. from preceding from that point.
...@@ -788,9 +787,9 @@ def vars_between( ...@@ -788,9 +787,9 @@ def vars_between(
Parameters Parameters
---------- ----------
ins: list ins : list
Input `Variable`s. Input `Variable`s.
outs: list outs : list
Output `Variable`s. Output `Variable`s.
Yields Yields
...@@ -817,9 +816,9 @@ def orphans_between( ...@@ -817,9 +816,9 @@ def orphans_between(
Parameters Parameters
---------- ----------
ins: list ins : list
Input `Variable`s. Input `Variable`s.
outs: list outs : list
Output `Variable`s. Output `Variable`s.
Yields Yields
...@@ -845,9 +844,9 @@ def applys_between( ...@@ -845,9 +844,9 @@ def applys_between(
Parameters Parameters
---------- ----------
ins: list ins : list
Input `Variable`s. Input `Variable`s.
outs: list outs : list
Output `Variable`s. Output `Variable`s.
Yields Yields
...@@ -972,15 +971,15 @@ def general_toposort( ...@@ -972,15 +971,15 @@ def general_toposort(
Parameters Parameters
---------- ----------
deps: callable deps : callable
A python function that takes a node as input and returns its dependence. A python function that takes a node as input and returns its dependence.
compute_deps_cache: optional compute_deps_cache : optional
If provided deps_cache should also be provided. This is a function like If provided deps_cache should also be provided. This is a function like
deps, but that also cache its results in a dict passed as deps_cache. deps, but that also cache its results in a dict passed as deps_cache.
deps_cache: dict deps_cache : dict
A dict mapping nodes to their children. This is populated by A dict mapping nodes to their children. This is populated by
`compute_deps_cache`. `compute_deps_cache`.
clients: dict clients : dict
If a dict is passed it will be filled with a mapping of If a dict is passed it will be filled with a mapping of
nodes-to-clients for each node in the subgraph. nodes-to-clients for each node in the subgraph.
...@@ -1357,9 +1356,9 @@ def list_of_nodes( ...@@ -1357,9 +1356,9 @@ def list_of_nodes(
Parameters Parameters
---------- ----------
inputs: list of Variable inputs : list of Variable
Input `Variable`s. Input `Variable`s.
outputs: list of Variable outputs : list of Variable
Output `Variable`s. Output `Variable`s.
""" """
...@@ -1380,9 +1379,9 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool: ...@@ -1380,9 +1379,9 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
Parameters Parameters
---------- ----------
l_apply: Apply l_apply : Apply
The node to walk. The node to walk.
f_apply: Apply f_apply : Apply
The node to find in `l_apply`. The node to find in `l_apply`.
Returns Returns
......
from theano.gof.type import Type from theano.gof.type import CType
class NullType(Type): class NullType(CType):
""" """
A type that allows no values. A type that allows no values.
......
...@@ -116,7 +116,7 @@ for more info about enumeration aliases). ...@@ -116,7 +116,7 @@ for more info about enumeration aliases).
import hashlib import hashlib
import re import re
from theano.gof.type import EnumType, Type from theano.gof.type import CType, EnumType
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
...@@ -315,14 +315,15 @@ class Params(dict): ...@@ -315,14 +315,15 @@ class Params(dict):
return not self.__eq__(other) return not self.__eq__(other)
class ParamsType(Type): class ParamsType(CType):
""" """
This class can create a struct of Theano types (like TensorType, GpuArrayType, etc.) This class can create a struct of Theano types (like `TensorType`,
to be used as a convenience op parameter wrapping many data. `GpuArrayType`, etc.) to be used as a convenience op parameter wrapping
many data.
ParamsType constructor takes key-value args. `ParamsType` constructor takes key-value args. Key will be the name of the
Key will be the name of the attribute in the struct. attribute in the struct. Value is the Theano type of this attribute,
Value is the Theano type of this attribute, ie. an instance of (a subclass of) :class:`Type` ie. an instance of (a subclass of) :class:`CType`
(eg. ``TensorType('int64', (False,))``). (eg. ``TensorType('int64', (False,))``).
In a Python code any attribute named ``key`` will be available via:: In a Python code any attribute named ``key`` will be available via::
...@@ -337,7 +338,8 @@ class ParamsType(Type): ...@@ -337,7 +338,8 @@ class ParamsType(Type):
.. note:: .. note::
This Type is not complete and should never be used for regular graph operations. This `Type` is not complete and should never be used for regular graph
operations.
""" """
...@@ -358,9 +360,9 @@ class ParamsType(Type): ...@@ -358,9 +360,9 @@ class ParamsType(Type):
) )
type_instance = kwargs[attribute_name] type_instance = kwargs[attribute_name]
type_name = type_instance.__class__.__name__ type_name = type_instance.__class__.__name__
if not isinstance(type_instance, Type): if not isinstance(type_instance, CType):
raise TypeError( raise TypeError(
'ParamsType: attribute "%s" should inherit from Theano Type, got "%s".' 'ParamsType: attribute "%s" should inherit from Theano CType, got "%s".'
% (attribute_name, type_name) % (attribute_name, type_name)
) )
......
""" """The `Type` classes."""
WRITEME
Defines the `Type` class.
"""
import ctypes import ctypes
...@@ -16,12 +11,13 @@ from theano.gof import graph, utils ...@@ -16,12 +11,13 @@ from theano.gof import graph, utils
from theano.gof.op import COp from theano.gof.op import COp
from theano.gof.utils import MetaObject, MethodNotDefined from theano.gof.utils import MetaObject, MethodNotDefined
from theano.link.c.interface import CLinkerType from theano.link.c.interface import CLinkerType
from theano.utils import Singleton
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
class PureType: class Type:
""" """
Interface specification for variable type instances. Interface specification for variable type instances.
...@@ -35,10 +31,10 @@ class PureType: ...@@ -35,10 +31,10 @@ class PureType:
""" """
# the type that will be created by call to make_variable. # the type that will be created by a call to make_variable.
Variable = graph.Variable Variable = graph.Variable
# the type that will be created by call to make_constant # the type that will be created by a call to make_constant
Constant = graph.Constant Constant = graph.Constant
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
...@@ -217,13 +213,9 @@ class PureType: ...@@ -217,13 +213,9 @@ class PureType:
""" """
_nothing = """ class CType(MetaObject, Type, CLinkerType):
"""
class Type(MetaObject, PureType, CLinkerType):
""" """
Convenience wrapper combining `PureType` and `CLinkerType`. Convenience wrapper combining `Type` and `CLinkerType`.
Theano comes with several subclasses of such as: Theano comes with several subclasses of such as:
...@@ -265,48 +257,11 @@ class Type(MetaObject, PureType, CLinkerType): ...@@ -265,48 +257,11 @@ class Type(MetaObject, PureType, CLinkerType):
""" """
class SingletonType(Type): class Generic(CType, Singleton):
"""
Convenient Base class for a Type subclass with no attributes.
It saves having to implement __eq__ and __hash__.
"""
__instance = None
def __new__(cls):
# If sub-subclass of SingletonType don't redeclare __instance
# when we look for it, we will find it in the subclass. We
# don't want that, so we check the class. When we add one, we
# add one only to the current class, so all is working
# correctly.
if cls.__instance is None or not isinstance(cls.__instance, cls):
cls.__instance = Type.__new__(cls)
return cls.__instance
def __str__(self):
return self.__class__.__name__
# even if we try to make a singleton, this do not always work. So
# we compare the type. See test_type_other.test_none_Constant for
# an exmple. So we need to implement __eq__ and __hash__
def __eq__(self, other):
if self is other:
return True
if type(self) is type(other):
return True
return False
def __hash__(self):
return hash(type(self))
class Generic(SingletonType):
""" """
Represents a generic Python object. Represents a generic Python object.
This class implements the `PureType` and `CLinkerType` interfaces This class implements the `CType` and `CLinkerType` interfaces
for generic PyObject instances. for generic PyObject instances.
EXAMPLE of what this means, or when you would use this type. EXAMPLE of what this means, or when you would use this type.
...@@ -400,7 +355,7 @@ class _make_cdata(COp): ...@@ -400,7 +355,7 @@ class _make_cdata(COp):
return (0, self.rtype.version) return (0, self.rtype.version)
class CDataType(Type): class CDataType(CType):
""" """
Represents opaque C data to be passed around. The intent is to Represents opaque C data to be passed around. The intent is to
ease passing arbitrary data between ops C code. ease passing arbitrary data between ops C code.
...@@ -613,7 +568,7 @@ class CDataTypeConstant(graph.Constant): ...@@ -613,7 +568,7 @@ class CDataTypeConstant(graph.Constant):
CDataType.Constant = CDataTypeConstant CDataType.Constant = CDataTypeConstant
class EnumType(Type, dict): class EnumType(CType, dict):
""" """
Main subclasses: Main subclasses:
- :class:`EnumList` - :class:`EnumList`
...@@ -804,12 +759,12 @@ class EnumType(Type, dict): ...@@ -804,12 +759,12 @@ class EnumType(Type, dict):
def __getattr__(self, key): def __getattr__(self, key):
if key in self: if key in self:
return self[key] return self[key]
return Type.__getattr__(self, key) return CType.__getattr__(self, key)
def __setattr__(self, key, value): def __setattr__(self, key, value):
if key in self: if key in self:
raise NotImplementedError("constant values are immutable.") raise NotImplementedError("constant values are immutable.")
Type.__setattr__(self, key, value) CType.__setattr__(self, key, value)
def __setitem__(self, key, value): def __setitem__(self, key, value):
raise NotImplementedError("constant values are immutable.") raise NotImplementedError("constant values are immutable.")
......
...@@ -12,7 +12,7 @@ from theano.gof.graph import Apply, Variable ...@@ -12,7 +12,7 @@ from theano.gof.graph import Apply, Variable
from theano.gof.op import COp, ExternalCOp, Op from theano.gof.op import COp, ExternalCOp, Op
from theano.gof.opt import copy_stack_trace from theano.gof.opt import copy_stack_trace
from theano.gof.params_type import ParamsType from theano.gof.params_type import ParamsType
from theano.gof.type import Type from theano.gof.type import CType
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.link.c.interface import HideC from theano.link.c.interface import HideC
...@@ -220,7 +220,7 @@ class Kernel: ...@@ -220,7 +220,7 @@ class Kernel:
def get_dtype(t): def get_dtype(t):
if isinstance(t, str): if isinstance(t, str):
return np.dtype(t) return np.dtype(t)
elif isinstance(t, Type): elif isinstance(t, CType):
return t.dtype return t.dtype
elif isinstance(t, Variable): elif isinstance(t, Variable):
return t.type.dtype return t.type.dtype
......
...@@ -6,6 +6,7 @@ import theano.tensor as tt ...@@ -6,6 +6,7 @@ import theano.tensor as tt
from theano import gof from theano import gof
from theano.gof.op import COp, Op from theano.gof.op import COp, Op
from theano.gof.params_type import ParamsType from theano.gof.params_type import ParamsType
from theano.gof.type import CType
from theano.gradient import grad_not_implemented from theano.gradient import grad_not_implemented
from theano.link.c.interface import HideC from theano.link.c.interface import HideC
from theano.scalar import bool as bool_t from theano.scalar import bool as bool_t
...@@ -160,7 +161,7 @@ class GpuSubtensor(HideC, Subtensor): ...@@ -160,7 +161,7 @@ class GpuSubtensor(HideC, Subtensor):
return "0", 1 return "0", 1
elif isinstance(idx, (np.integer, int)): elif isinstance(idx, (np.integer, int)):
return str(idx), 0 return str(idx), 0
elif isinstance(idx, gof.Type): elif isinstance(idx, CType):
return indices.pop(0), 0 return indices.pop(0), 0
else: else:
assert 0, idx assert 0, idx
...@@ -195,7 +196,7 @@ class GpuSubtensor(HideC, Subtensor): ...@@ -195,7 +196,7 @@ class GpuSubtensor(HideC, Subtensor):
file=sio, file=sio,
) )
else: else:
if isinstance(idx, gof.Type): if isinstance(idx, CType):
start = indices.pop(0) start = indices.pop(0)
elif isinstance(idx, (np.integer, int)): elif isinstance(idx, (np.integer, int)):
start = idx start = idx
...@@ -263,7 +264,7 @@ class GpuIncSubtensor(IncSubtensor): ...@@ -263,7 +264,7 @@ class GpuIncSubtensor(IncSubtensor):
indices = list(reversed(inputs[2:])) indices = list(reversed(inputs[2:]))
def convert(entry): def convert(entry):
if isinstance(entry, gof.Type): if isinstance(entry, CType):
rval = indices.pop() rval = indices.pop()
return rval return rval
elif isinstance(entry, slice): elif isinstance(entry, slice):
......
...@@ -6,7 +6,7 @@ import warnings ...@@ -6,7 +6,7 @@ import warnings
import numpy as np import numpy as np
import theano import theano
from theano import Constant, Type, Variable, config, scalar, tensor from theano import Constant, CType, Variable, config, scalar, tensor
from theano.compile import SharedVariable from theano.compile import SharedVariable
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
from theano.tensor.type import TensorType from theano.tensor.type import TensorType
...@@ -127,7 +127,7 @@ def _unreg_context(name): ...@@ -127,7 +127,7 @@ def _unreg_context(name):
del _context_reg[name] del _context_reg[name]
class GpuArrayType(Type): class GpuArrayType(CType):
""" """
The type that represents an array on a gpu. The type that represents an array on a gpu.
...@@ -173,7 +173,7 @@ class GpuArrayType(Type): ...@@ -173,7 +173,7 @@ class GpuArrayType(Type):
See Also See Also
-------- --------
theano.gof.type.PureType theano.gof.type.Type
""" """
...@@ -883,7 +883,7 @@ theano.compile.register_specify_shape_c_code( ...@@ -883,7 +883,7 @@ theano.compile.register_specify_shape_c_code(
) )
class GpuContextType(Type): class GpuContextType(CType):
""" """
Minimal type used for passing contexts to nodes. Minimal type used for passing contexts to nodes.
......
...@@ -119,7 +119,7 @@ def grad_undefined(op, x_pos, x, comment=""): ...@@ -119,7 +119,7 @@ def grad_undefined(op, x_pos, x, comment=""):
)() )()
class DisconnectedType(theano.gof.type.Type): class DisconnectedType(theano.gof.type.CType):
"""A type indicating that a variable is a result """A type indicating that a variable is a result
of taking the gradient of c with respect to x of taking the gradient of c with respect to x
......
...@@ -4,7 +4,7 @@ from copy import copy, deepcopy ...@@ -4,7 +4,7 @@ from copy import copy, deepcopy
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply from theano.gof.graph import Apply
from theano.gof.type import Type from theano.gof.type import CType
from theano.link.utils import gc_helper, map_storage, raise_with_op, streamline from theano.link.utils import gc_helper, map_storage, raise_with_op, streamline
from theano.utils import deprecated, difference, to_return_values from theano.utils import deprecated, difference, to_return_values
...@@ -45,7 +45,7 @@ class Container: ...@@ -45,7 +45,7 @@ class Container:
): ):
if not isinstance(storage, list) or not len(storage) >= 1: if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one") raise TypeError("storage must be a list of length at least one")
if isinstance(r, Type): if isinstance(r, CType):
self.type = r self.type = r
else: else:
self.type = r.type self.type = r.type
......
...@@ -6,6 +6,7 @@ import logging ...@@ -6,6 +6,7 @@ import logging
import os import os
import sys import sys
from collections import defaultdict from collections import defaultdict
from contextlib import suppress
from copy import copy from copy import copy
from io import StringIO from io import StringIO
...@@ -25,6 +26,7 @@ from theano.link.c.cmodule import ( ...@@ -25,6 +26,7 @@ from theano.link.c.cmodule import (
dlimport_workdir, dlimport_workdir,
) )
from theano.link.c.cmodule import get_module_cache as _get_module_cache from theano.link.c.cmodule import get_module_cache as _get_module_cache
from theano.link.c.interface import CLinkerObject, CLinkerOp, CLinkerType
from theano.link.utils import gc_helper, map_storage, raise_with_op, streamline from theano.link.utils import gc_helper, map_storage, raise_with_op, streamline
from theano.utils import difference, uniq from theano.utils import difference, uniq
...@@ -667,13 +669,13 @@ class CLinker(Linker): ...@@ -667,13 +669,13 @@ class CLinker(Linker):
self.consts = [] self.consts = []
# Move c type from orphans (theano.scalar.Scalar) to self.consts # Move c type from orphans (theano.scalar.Scalar) to self.consts
for variable in self.orphans: for variable in self.orphans:
if isinstance(variable, Constant): if isinstance(variable, Constant) and isinstance(
try: variable.type, CLinkerType
):
with suppress(MethodNotDefined, NotImplementedError):
variable.type.c_literal(variable.data) variable.type.c_literal(variable.data)
self.consts.append(variable) self.consts.append(variable)
self.orphans.remove(variable) self.orphans.remove(variable)
except (MethodNotDefined, NotImplementedError):
pass
self.temps = list( self.temps = list(
set(self.variables) set(self.variables)
...@@ -721,6 +723,10 @@ class CLinker(Linker): ...@@ -721,6 +723,10 @@ class CLinker(Linker):
id = 1 id = 1
for variable in self.variables: for variable in self.variables:
if not isinstance(variable.type, CLinkerType):
raise NotImplementedError(f"Type of {variable} cannot produce C code")
sub = dict(failure_var=failure_var) sub = dict(failure_var=failure_var)
# it might be possible to inline constant variables as C literals # it might be possible to inline constant variables as C literals
...@@ -816,6 +822,11 @@ class CLinker(Linker): ...@@ -816,6 +822,11 @@ class CLinker(Linker):
for node_num, node in enumerate(self.node_order): for node_num, node in enumerate(self.node_order):
op = node.op
if not isinstance(op, CLinkerOp):
raise NotImplementedError(f"{op} cannot produce C code")
sub = dict(failure_var=failure_var) sub = dict(failure_var=failure_var)
params = node.run_params() params = node.run_params()
...@@ -849,56 +860,43 @@ class CLinker(Linker): ...@@ -849,56 +860,43 @@ class CLinker(Linker):
struct_init = "" struct_init = ""
struct_cleanup = "" struct_cleanup = ""
op = node.op with suppress(MethodNotDefined):
# type-specific support code
try:
c_support_code_apply.append(op.c_support_code_apply(node, name)) c_support_code_apply.append(op.c_support_code_apply(node, name))
except MethodNotDefined:
pass
else:
# The following will be executed if the "try" block succeeds
assert isinstance(c_support_code_apply[-1], str), ( assert isinstance(c_support_code_apply[-1], str), (
str(node.op) + " didn't return a string for c_support_code_apply" str(node.op) + " didn't return a string for c_support_code_apply"
) )
try: with suppress(MethodNotDefined):
c_init_code_apply.append(op.c_init_code_apply(node, name)) c_init_code_apply.append(op.c_init_code_apply(node, name))
except MethodNotDefined:
pass
else:
assert isinstance(c_init_code_apply[-1], str), ( assert isinstance(c_init_code_apply[-1], str), (
str(node.op) + " didn't return a string for c_init_code_apply" str(node.op) + " didn't return a string for c_init_code_apply"
) )
try: with suppress(MethodNotDefined):
struct_init = op.c_init_code_struct(node, name, sub_struct) struct_init = op.c_init_code_struct(node, name, sub_struct)
assert isinstance(struct_init, str), ( assert isinstance(struct_init, str), (
str(node.op) + " didn't return a string for c_init_code_struct" str(node.op) + " didn't return a string for c_init_code_struct"
) )
except MethodNotDefined:
pass
try: with suppress(MethodNotDefined):
struct_support = op.c_support_code_struct(node, name) struct_support = op.c_support_code_struct(node, name)
assert isinstance(struct_support, str), ( assert isinstance(struct_support, str), (
str(node.op) + " didn't return a string for c_support_code_struct" str(node.op) + " didn't return a string for c_support_code_struct"
) )
except MethodNotDefined:
pass
try: with suppress(MethodNotDefined):
struct_cleanup = op.c_cleanup_code_struct(node, name) struct_cleanup = op.c_cleanup_code_struct(node, name)
assert isinstance(struct_cleanup, str), ( assert isinstance(struct_cleanup, str), (
str(node.op) + " didn't return a string for c_cleanup_code_struct" str(node.op) + " didn't return a string for c_cleanup_code_struct"
) )
except MethodNotDefined:
pass
# emit c_code # emit c_code
try: try:
behavior = op.c_code(node, name, isyms, osyms, sub) behavior = op.c_code(node, name, isyms, osyms, sub)
except MethodNotDefined: except MethodNotDefined:
raise NotImplementedError(f"{op} cannot produce C code") raise NotImplementedError(f"{op} cannot produce C code")
assert isinstance( assert isinstance(
behavior, str behavior, str
), f"{node.op} didn't return a string for c_code" ), f"{node.op} didn't return a string for c_code"
...@@ -987,14 +985,12 @@ class CLinker(Linker): ...@@ -987,14 +985,12 @@ class CLinker(Linker):
) )
# generic support code # generic support code
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
try: with suppress(MethodNotDefined):
support_code = x.c_support_code() support_code = x.c_support_code()
if isinstance(support_code, list): if isinstance(support_code, list):
ret.extend(support_code) ret.extend(support_code)
else: else:
ret.append(support_code) ret.append(support_code)
except MethodNotDefined:
pass
return ret return ret
def compile_args(self): def compile_args(self):
...@@ -1026,20 +1022,20 @@ class CLinker(Linker): ...@@ -1026,20 +1022,20 @@ class CLinker(Linker):
c_compiler = self.c_compiler() c_compiler = self.c_compiler()
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
try: if isinstance(x, CLinkerObject):
with suppress(MethodNotDefined):
try: try:
ret += x.c_compile_args(c_compiler) ret += x.c_compile_args(c_compiler)
except TypeError: except TypeError:
ret += x.c_compile_args() ret += x.c_compile_args()
except MethodNotDefined:
pass
ret = uniq(ret) # to remove duplicate ret = uniq(ret) # to remove duplicate
# The args set by the compiler include the user flags. We do not want # The args set by the compiler include the user flags. We do not want
# to reorder them # to reorder them
ret += c_compiler.compile_args() ret += c_compiler.compile_args()
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
try: if isinstance(x, CLinkerObject):
with suppress(MethodNotDefined):
try: try:
no_comp = x.c_no_compile_args(c_compiler) no_comp = x.c_no_compile_args(c_compiler)
except TypeError: except TypeError:
...@@ -1049,8 +1045,6 @@ class CLinker(Linker): ...@@ -1049,8 +1045,6 @@ class CLinker(Linker):
ret.remove(i) ret.remove(i)
except ValueError: except ValueError:
pass # in case the value is not there pass # in case the value is not there
except MethodNotDefined:
pass
return ret return ret
def headers(self): def headers(self):
...@@ -1064,13 +1058,12 @@ class CLinker(Linker): ...@@ -1064,13 +1058,12 @@ class CLinker(Linker):
ret = [] ret = []
c_compiler = self.c_compiler() c_compiler = self.c_compiler()
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
try: if isinstance(x, CLinkerObject):
with suppress(MethodNotDefined):
try: try:
ret += x.c_headers(c_compiler) ret += x.c_headers(c_compiler)
except TypeError: except TypeError:
ret += x.c_headers() ret += x.c_headers()
except MethodNotDefined:
pass
return uniq(ret) return uniq(ret)
def init_code(self): def init_code(self):
...@@ -1083,15 +1076,15 @@ class CLinker(Linker): ...@@ -1083,15 +1076,15 @@ class CLinker(Linker):
""" """
ret = [] ret = []
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
try: if isinstance(x, CLinkerObject):
with suppress(MethodNotDefined):
ret += x.c_init_code() ret += x.c_init_code()
except MethodNotDefined:
pass
return uniq(ret) return uniq(ret)
def c_compiler(self): def c_compiler(self):
c_compiler = None c_compiler = None
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
# FIXME: Why would a `Type` have a `c_compiler` field?!
if hasattr(x, "c_compiler"): if hasattr(x, "c_compiler"):
x_compiler = x.c_compiler() x_compiler = x.c_compiler()
else: else:
...@@ -1121,13 +1114,12 @@ class CLinker(Linker): ...@@ -1121,13 +1114,12 @@ class CLinker(Linker):
ret = [] ret = []
c_compiler = self.c_compiler() c_compiler = self.c_compiler()
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
try: if isinstance(x, CLinkerObject):
with suppress(MethodNotDefined):
try: try:
ret += x.c_header_dirs(c_compiler) ret += x.c_header_dirs(c_compiler)
except TypeError: except TypeError:
ret += x.c_header_dirs() ret += x.c_header_dirs()
except MethodNotDefined:
pass
# filter out empty strings/None # filter out empty strings/None
return [r for r in uniq(ret) if r] return [r for r in uniq(ret) if r]
...@@ -1142,13 +1134,12 @@ class CLinker(Linker): ...@@ -1142,13 +1134,12 @@ class CLinker(Linker):
ret = [] ret = []
c_compiler = self.c_compiler() c_compiler = self.c_compiler()
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
try: if isinstance(x, CLinkerObject):
with suppress(MethodNotDefined):
try: try:
ret += x.c_libraries(c_compiler) ret += x.c_libraries(c_compiler)
except TypeError: except TypeError:
ret += x.c_libraries() ret += x.c_libraries()
except MethodNotDefined:
pass
return uniq(ret) return uniq(ret)
def lib_dirs(self): def lib_dirs(self):
...@@ -1162,13 +1153,12 @@ class CLinker(Linker): ...@@ -1162,13 +1153,12 @@ class CLinker(Linker):
ret = [] ret = []
c_compiler = self.c_compiler() c_compiler = self.c_compiler()
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [y.op for y in self.node_order]:
try: if isinstance(x, CLinkerObject):
with suppress(MethodNotDefined):
try: try:
ret += x.c_lib_dirs(c_compiler) ret += x.c_lib_dirs(c_compiler)
except TypeError: except TypeError:
ret += x.c_lib_dirs() ret += x.c_lib_dirs()
except MethodNotDefined:
pass
# filter out empty strings/None # filter out empty strings/None
return [r for r in uniq(ret) if r] return [r for r in uniq(ret) if r]
...@@ -1542,8 +1532,10 @@ class CLinker(Linker): ...@@ -1542,8 +1532,10 @@ class CLinker(Linker):
if hasattr(node.op, "__props__"): if hasattr(node.op, "__props__"):
version.append(node.op.__props__) version.append(node.op.__props__)
for i in node.inputs: for i in node.inputs:
if isinstance(i.type, CLinkerObject):
version.append(i.type.c_code_cache_version()) version.append(i.type.c_code_cache_version())
for o in node.outputs: for o in node.outputs:
if isinstance(o.type, CLinkerObject):
version.append(o.type.c_code_cache_version()) version.append(o.type.c_code_cache_version())
# add the signature for this node # add the signature for this node
......
...@@ -6,7 +6,6 @@ import jax ...@@ -6,7 +6,6 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax.scipy as jsp import jax.scipy as jsp
import theano
from theano.compile.ops import ( from theano.compile.ops import (
DeepCopyOp, DeepCopyOp,
Rebroadcast, Rebroadcast,
...@@ -17,6 +16,7 @@ from theano.compile.ops import ( ...@@ -17,6 +16,7 @@ from theano.compile.ops import (
) )
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof import FunctionGraph from theano.gof import FunctionGraph
from theano.gof.type import CType
from theano.ifelse import IfElse from theano.ifelse import IfElse
from theano.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second from theano.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from theano.scan.op import Scan from theano.scan.op import Scan
...@@ -589,7 +589,7 @@ def jax_funcify_IfElse(op): ...@@ -589,7 +589,7 @@ def jax_funcify_IfElse(op):
def convert_indices(indices, entry): def convert_indices(indices, entry):
if indices and isinstance(entry, theano.gof.Type): if indices and isinstance(entry, CType):
rval = indices.pop(0) rval = indices.pop(0)
return rval return rval
elif isinstance(entry, slice): elif isinstance(entry, slice):
......
...@@ -25,7 +25,7 @@ from theano.gof.fg import FunctionGraph ...@@ -25,7 +25,7 @@ from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Constant, Variable, clone, list_of_nodes from theano.gof.graph import Apply, Constant, Variable, clone, list_of_nodes
from theano.gof.op import COp from theano.gof.op import COp
from theano.gof.opt import MergeOptimizer from theano.gof.opt import MergeOptimizer
from theano.gof.type import Type from theano.gof.type import CType
from theano.gof.utils import MetaObject, MethodNotDefined from theano.gof.utils import MetaObject, MethodNotDefined
from theano.gradient import DisconnectedType, grad_undefined from theano.gradient import DisconnectedType, grad_undefined
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
...@@ -313,7 +313,7 @@ def constant(x, name=None, dtype=None): ...@@ -313,7 +313,7 @@ def constant(x, name=None, dtype=None):
return ScalarConstant(get_scalar_type(str(x.dtype)), x, name=name) return ScalarConstant(get_scalar_type(str(x.dtype)), x, name=name)
class Scalar(Type): class Scalar(CType):
""" """
Internal class, should not be used by clients. Internal class, should not be used by clients.
...@@ -1096,7 +1096,7 @@ class ScalarOp(COp): ...@@ -1096,7 +1096,7 @@ class ScalarOp(COp):
if hasattr(self, "output_types_preference"): if hasattr(self, "output_types_preference"):
variables = self.output_types_preference(*types) variables = self.output_types_preference(*types)
if not isinstance(variables, (list, tuple)) or any( if not isinstance(variables, (list, tuple)) or any(
not isinstance(x, Type) for x in variables not isinstance(x, CType) for x in variables
): ):
raise TypeError( raise TypeError(
"output_types_preference should return a list or a tuple of types", "output_types_preference should return a list or a tuple of types",
......
...@@ -32,7 +32,7 @@ def _is_sparse(x): ...@@ -32,7 +32,7 @@ def _is_sparse(x):
return isinstance(x, scipy.sparse.spmatrix) return isinstance(x, scipy.sparse.spmatrix)
class SparseType(gof.Type): class SparseType(gof.CType):
""" """
Fundamental way to create a sparse node. Fundamental way to create a sparse node.
......
...@@ -565,7 +565,7 @@ def get_scalar_constant_value( ...@@ -565,7 +565,7 @@ def get_scalar_constant_value(
var.ndim == 0 for var in v.owner.inputs[0].owner.inputs[1:] var.ndim == 0 for var in v.owner.inputs[0].owner.inputs[1:]
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, gof.Type): if isinstance(idx, gof.CType):
idx = get_scalar_constant_value( idx = get_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
...@@ -579,7 +579,7 @@ def get_scalar_constant_value( ...@@ -579,7 +579,7 @@ def get_scalar_constant_value(
var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:] var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:]
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, gof.Type): if isinstance(idx, gof.CType):
idx = get_scalar_constant_value( idx = get_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
...@@ -616,7 +616,7 @@ def get_scalar_constant_value( ...@@ -616,7 +616,7 @@ def get_scalar_constant_value(
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, gof.Type): if isinstance(idx, gof.CType):
idx = get_scalar_constant_value( idx = get_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
...@@ -638,7 +638,7 @@ def get_scalar_constant_value( ...@@ -638,7 +638,7 @@ def get_scalar_constant_value(
op = owner.op op = owner.op
idx_list = op.idx_list idx_list = op.idx_list
idx = idx_list[0] idx = idx_list[0]
if isinstance(idx, gof.Type): if isinstance(idx, gof.CType):
idx = get_scalar_constant_value( idx = get_scalar_constant_value(
owner.inputs[1], max_recur=max_recur owner.inputs[1], max_recur=max_recur
) )
......
...@@ -3,10 +3,10 @@ import sys ...@@ -3,10 +3,10 @@ import sys
import numpy as np import numpy as np
import theano import theano
from theano.gof.type import Type from theano.gof.type import CType
class RandomStateType(Type): class RandomStateType(CType):
"""A Type wrapper for `numpy.random.RandomState`. """A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that The reason this exists (and `Generic` doesn't suffice) is that
......
...@@ -7,13 +7,13 @@ from textwrap import dedent ...@@ -7,13 +7,13 @@ from textwrap import dedent
import numpy as np import numpy as np
import theano import theano
from theano import gof
from theano import scalar as scal from theano import scalar as scal
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof import MethodNotDefined, ParamsType from theano.gof.graph import Apply, Variable
from theano.gof.graph import Apply
from theano.gof.op import COp, Op from theano.gof.op import COp, Op
from theano.gof.type import Type from theano.gof.params_type import ParamsType
from theano.gof.type import CType
from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
from theano.printing import pprint from theano.printing import pprint
...@@ -56,7 +56,7 @@ def as_index_constant(a): ...@@ -56,7 +56,7 @@ def as_index_constant(a):
) )
elif isinstance(a, (int, np.integer)): elif isinstance(a, (int, np.integer)):
return scal.ScalarConstant(scal.int64, a) return scal.ScalarConstant(scal.int64, a)
elif not isinstance(a, theano.tensor.Variable): elif not isinstance(a, Variable):
return theano.tensor.as_tensor(a) return theano.tensor.as_tensor(a)
else: else:
return a return a
...@@ -82,7 +82,7 @@ def get_idx_list(inputs, idx_list, get_count=False): ...@@ -82,7 +82,7 @@ def get_idx_list(inputs, idx_list, get_count=False):
# General case # General case
def convert(entry): def convert(entry):
if isinstance(entry, gof.Type): if isinstance(entry, CType):
return indices.pop() return indices.pop()
elif isinstance(entry, slice): elif isinstance(entry, slice):
return slice(convert(entry.start), convert(entry.stop), convert(entry.step)) return slice(convert(entry.start), convert(entry.stop), convert(entry.step))
...@@ -115,7 +115,7 @@ def get_canonical_form_slice(theslice, length): ...@@ -115,7 +115,7 @@ def get_canonical_form_slice(theslice, length):
try: try:
x_constant = get_scalar_constant_value(x) x_constant = get_scalar_constant_value(x)
is_constant = True is_constant = True
except theano.tensor.NotScalarConstantError: except NotScalarConstantError:
x_constant = theano.tensor.extract_constant(x) x_constant = theano.tensor.extract_constant(x)
is_constant = False is_constant = False
return x_constant, is_constant return x_constant, is_constant
...@@ -487,30 +487,30 @@ class Subtensor(COp): ...@@ -487,30 +487,30 @@ class Subtensor(COp):
) )
if ( if (
isinstance(entry, (np.ndarray, theano.tensor.Variable)) isinstance(entry, (np.ndarray, Variable))
and hasattr(entry, "dtype") and hasattr(entry, "dtype")
and entry.dtype == "bool" and entry.dtype == "bool"
): ):
raise AdvancedIndexingError("Invalid index type or slice for Subtensor") raise AdvancedIndexingError("Invalid index type or slice for Subtensor")
if isinstance(entry, gof.Variable) and ( if isinstance(entry, Variable) and (
entry.type in invalid_scal_types or entry.type in invalid_tensor_types entry.type in invalid_scal_types or entry.type in invalid_tensor_types
): ):
raise TypeError("Expected an integer") raise TypeError("Expected an integer")
if isinstance(entry, gof.Variable) and entry.type in scal_types: if isinstance(entry, Variable) and entry.type in scal_types:
return entry.type return entry.type
elif isinstance(entry, gof.Type) and entry in scal_types: elif isinstance(entry, CType) and entry in scal_types:
return entry return entry
if ( if (
isinstance(entry, gof.Variable) isinstance(entry, Variable)
and entry.type in tensor_types and entry.type in tensor_types
and np.all(entry.type.broadcastable) and np.all(entry.type.broadcastable)
): ):
return scal.get_scalar_type(entry.type.dtype) return scal.get_scalar_type(entry.type.dtype)
elif ( elif (
isinstance(entry, gof.Type) isinstance(entry, CType)
and entry in tensor_types and entry in tensor_types
and np.all(entry.broadcastable) and np.all(entry.broadcastable)
): ):
...@@ -553,7 +553,7 @@ class Subtensor(COp): ...@@ -553,7 +553,7 @@ class Subtensor(COp):
""" """
Return the idx_list with constant inputs replaced by their Return the idx_list with constant inputs replaced by their
python scalar equivalent. python scalar equivalent.
May raise `theano.tensor.NotScalarConstantError` if the idx contains May raise `NotScalarConstantError` if the idx contains
non-constant entries. non-constant entries.
If allow_partial is True, then entries that are not constant will If allow_partial is True, then entries that are not constant will
...@@ -594,7 +594,7 @@ class Subtensor(COp): ...@@ -594,7 +594,7 @@ class Subtensor(COp):
only_process_constants=only_process_constants, only_process_constants=only_process_constants,
elemwise=elemwise, elemwise=elemwise,
) )
except theano.tensor.NotScalarConstantError: except NotScalarConstantError:
if allow_partial: if allow_partial:
return val return val
else: else:
...@@ -610,7 +610,7 @@ class Subtensor(COp): ...@@ -610,7 +610,7 @@ class Subtensor(COp):
# Since scal.as_scalar does not know about tensor types (it would # Since scal.as_scalar does not know about tensor types (it would
# create a circular import) , this method converts either a # create a circular import) , this method converts either a
# TensorVariable or a ScalarVariable to a scalar. # TensorVariable or a ScalarVariable to a scalar.
if isinstance(a, gof.Variable) and isinstance(a.type, TensorType): if isinstance(a, Variable) and isinstance(a.type, TensorType):
return theano.tensor.scalar_from_tensor(a) return theano.tensor.scalar_from_tensor(a)
else: else:
return scal.as_scalar(a) return scal.as_scalar(a)
...@@ -633,7 +633,7 @@ class Subtensor(COp): ...@@ -633,7 +633,7 @@ class Subtensor(COp):
raise IndexError("too many indices for array") raise IndexError("too many indices for array")
input_types = Subtensor.collapse( input_types = Subtensor.collapse(
idx_list, lambda entry: isinstance(entry, gof.Type) idx_list, lambda entry: isinstance(entry, CType)
) )
if len(inputs) != len(input_types): if len(inputs) != len(input_types):
raise IndexError( raise IndexError(
...@@ -672,7 +672,7 @@ class Subtensor(COp): ...@@ -672,7 +672,7 @@ class Subtensor(COp):
broadcastable.append(False) broadcastable.append(False)
return gof.Apply( return Apply(
self, self,
(x,) + inputs, (x,) + inputs,
[theano.tensor.tensor(dtype=x.type.dtype, broadcastable=broadcastable)], [theano.tensor.tensor(dtype=x.type.dtype, broadcastable=broadcastable)],
...@@ -851,7 +851,7 @@ class Subtensor(COp): ...@@ -851,7 +851,7 @@ class Subtensor(COp):
inc_spec_pos(1) inc_spec_pos(1)
if depth == 0: if depth == 0:
is_slice.append(0) is_slice.append(0)
elif isinstance(entry, Type): elif isinstance(entry, CType):
init_cmds.append( init_cmds.append(
"subtensor_spec[%i] = %s;" % (spec_pos(), inputs[input_pos()]) "subtensor_spec[%i] = %s;" % (spec_pos(), inputs[input_pos()])
) )
...@@ -1050,7 +1050,7 @@ class Subtensor(COp): ...@@ -1050,7 +1050,7 @@ class Subtensor(COp):
return (9,) return (9,)
def c_code(self, node, name, inputs, outputs, sub): # DEBUG def c_code(self, node, name, inputs, outputs, sub): # DEBUG
if not isinstance(node.inputs[0].type, theano.tensor.TensorType): if not isinstance(node.inputs[0].type, TensorType):
raise NotImplementedError() raise NotImplementedError()
x = inputs[0] x = inputs[0]
...@@ -1469,7 +1469,7 @@ class IncSubtensor(COp): ...@@ -1469,7 +1469,7 @@ class IncSubtensor(COp):
raise IndexError("too many indices for array") raise IndexError("too many indices for array")
input_types = Subtensor.collapse( input_types = Subtensor.collapse(
idx_list, lambda entry: isinstance(entry, gof.Type) idx_list, lambda entry: isinstance(entry, CType)
) )
if len(inputs) != len(input_types): if len(inputs) != len(input_types):
raise IndexError( raise IndexError(
...@@ -1482,7 +1482,7 @@ class IncSubtensor(COp): ...@@ -1482,7 +1482,7 @@ class IncSubtensor(COp):
% (input.type, expected_type) % (input.type, expected_type)
) )
return gof.Apply(self, (x, y) + inputs, [x.type()]) return Apply(self, (x, y) + inputs, [x.type()])
def decl_view(self): def decl_view(self):
return "PyArrayObject * zview = NULL;" return "PyArrayObject * zview = NULL;"
...@@ -1493,7 +1493,7 @@ class IncSubtensor(COp): ...@@ -1493,7 +1493,7 @@ class IncSubtensor(COp):
indices = list(reversed(inputs[2:])) indices = list(reversed(inputs[2:]))
def convert(entry): def convert(entry):
if isinstance(entry, gof.Type): if isinstance(entry, CType):
return indices.pop() return indices.pop()
elif isinstance(entry, slice): elif isinstance(entry, slice):
return slice( return slice(
...@@ -1645,7 +1645,7 @@ class IncSubtensor(COp): ...@@ -1645,7 +1645,7 @@ class IncSubtensor(COp):
""" """
if not isinstance(node.inputs[0].type, theano.tensor.TensorType): if not isinstance(node.inputs[0].type, TensorType):
raise NotImplementedError() raise NotImplementedError()
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -2239,9 +2239,9 @@ def as_index_variable(idx): ...@@ -2239,9 +2239,9 @@ def as_index_variable(idx):
return NoneConst.clone() return NoneConst.clone()
if isinstance(idx, slice): if isinstance(idx, slice):
return make_slice(idx) return make_slice(idx)
if isinstance(idx, gof.Variable) and isinstance(idx.type, SliceType): if isinstance(idx, Variable) and isinstance(idx.type, SliceType):
return idx return idx
if isinstance(idx, gof.Variable) and isinstance(idx.type, NoneTypeT): if isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT):
return idx return idx
idx = theano.tensor.as_tensor_variable(idx) idx = theano.tensor.as_tensor_variable(idx)
if idx.type.dtype not in theano.tensor.discrete_dtypes: if idx.type.dtype not in theano.tensor.discrete_dtypes:
...@@ -2312,7 +2312,7 @@ class AdvancedSubtensor(Op): ...@@ -2312,7 +2312,7 @@ class AdvancedSubtensor(Op):
for i in indexed_result_shape(fake_shape, bcast_index) for i in indexed_result_shape(fake_shape, bcast_index)
] ]
return gof.Apply( return Apply(
self, self,
(x,) + index, (x,) + index,
[theano.tensor.tensor(dtype=x.type.dtype, broadcastable=bcast)], [theano.tensor.tensor(dtype=x.type.dtype, broadcastable=bcast)],
...@@ -2415,7 +2415,7 @@ class AdvancedIncSubtensor(Op): ...@@ -2415,7 +2415,7 @@ class AdvancedIncSubtensor(Op):
if isinstance(inp, (list, tuple)): if isinstance(inp, (list, tuple)):
inp = theano.tensor.as_tensor_variable(inp) inp = theano.tensor.as_tensor_variable(inp)
new_inputs.append(inp) new_inputs.append(inp)
return gof.Apply( return Apply(
self, self,
(x, y) + tuple(new_inputs), (x, y) + tuple(new_inputs),
[ [
......
...@@ -7,14 +7,14 @@ import theano ...@@ -7,14 +7,14 @@ import theano
from theano import scalar as scal from theano import scalar as scal
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.graph import Variable from theano.gof.graph import Variable
from theano.gof.type import Type from theano.gof.type import CType
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
_logger = logging.getLogger("theano.tensor.type") _logger = logging.getLogger("theano.tensor.type")
class TensorType(Type): class TensorType(CType):
""" """
Symbolic `Type` representing a numpy.ndarray value. Symbolic `Type` representing a numpy.ndarray value.
......
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
import theano import theano
from theano.gof.graph import Apply, Constant from theano.gof.graph import Apply, Constant
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Generic, Type from theano.gof.type import CType, Generic
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
...@@ -49,7 +49,7 @@ class MakeSlice(Op): ...@@ -49,7 +49,7 @@ class MakeSlice(Op):
make_slice = MakeSlice() make_slice = MakeSlice()
class SliceType(Type): class SliceType(CType):
def filter(self, x, strict=False, allow_downcast=None): def filter(self, x, strict=False, allow_downcast=None):
if isinstance(x, slice): if isinstance(x, slice):
return x return x
......
from theano.gof.type import Type from theano.gof.type import CType, Type
class TypedListType(Type): class TypedListType(CType):
""" """
Parameters Parameters
......
...@@ -119,39 +119,6 @@ def get_unbound_function(unbound): ...@@ -119,39 +119,6 @@ def get_unbound_function(unbound):
return unbound return unbound
class DefaultOrderedDict(OrderedDict):
def __init__(self, default_factory=None, *a, **kw):
if default_factory is not None and not isinstance(default_factory, Callable):
raise TypeError("first argument must be callable")
OrderedDict.__init__(self, *a, **kw)
self.default_factory = default_factory
def __getitem__(self, key):
try:
return OrderedDict.__getitem__(self, key)
except KeyError:
return self.__missing__(key)
def __missing__(self, key):
if self.default_factory is None:
raise KeyError(key)
self[key] = value = self.default_factory()
return value
def __reduce__(self):
if self.default_factory is None:
args = tuple()
else:
args = (self.default_factory,)
return type(self), args, None, None, list(self.items())
def copy(self):
return self.__copy__()
def __copy__(self):
return type(self)(self.default_factory, self)
def maybe_add_to_os_environ_pathlist(var, newpath): def maybe_add_to_os_environ_pathlist(var, newpath):
"""Unfortunately, Conda offers to make itself the default Python """Unfortunately, Conda offers to make itself the default Python
and those who use it that way will probably not activate envs and those who use it that way will probably not activate envs
...@@ -377,22 +344,6 @@ def flatten(a): ...@@ -377,22 +344,6 @@ def flatten(a):
return [a] return [a]
class NoDuplicateOptWarningFilter(logging.Filter):
"""Filter to avoid duplicating optimization warnings."""
prev_msgs = set()
def filter(self, record):
msg = record.getMessage()
if msg.startswith("Optimization Warning: "):
if msg in self.prev_msgs:
return False
else:
self.prev_msgs.add(msg)
return True
return True
def apply_across_args(*fns): def apply_across_args(*fns):
"""Create new functions that distributes the wrapped functions across iterable arguments. """Create new functions that distributes the wrapped functions across iterable arguments.
...@@ -418,3 +369,85 @@ def apply_across_args(*fns): ...@@ -418,3 +369,85 @@ def apply_across_args(*fns):
return partial(f2, fns[0]) return partial(f2, fns[0])
else: else:
return [partial(f2, f) for f in fns] return [partial(f2, f) for f in fns]
class NoDuplicateOptWarningFilter(logging.Filter):
"""Filter to avoid duplicating optimization warnings."""
prev_msgs = set()
def filter(self, record):
msg = record.getMessage()
if msg.startswith("Optimization Warning: "):
if msg in self.prev_msgs:
return False
else:
self.prev_msgs.add(msg)
return True
return True
class Singleton:
"""Convenient base class for a singleton.
It saves having to implement __eq__ and __hash__.
"""
__instance = None
def __new__(cls):
# If sub-subclass of SingletonType don't redeclare __instance
# when we look for it, we will find it in the subclass. We
# don't want that, so we check the class. When we add one, we
# add one only to the current class, so all is working
# correctly.
if cls.__instance is None or not isinstance(cls.__instance, cls):
cls.__instance = super().__new__(cls)
return cls.__instance
def __str__(self):
return self.__class__.__name__
def __eq__(self, other):
if self is other:
return True
if type(self) is type(other):
return True
return False
def __hash__(self):
return hash(type(self))
class DefaultOrderedDict(OrderedDict):
def __init__(self, default_factory=None, *a, **kw):
if default_factory is not None and not isinstance(default_factory, Callable):
raise TypeError("first argument must be callable")
OrderedDict.__init__(self, *a, **kw)
self.default_factory = default_factory
def __getitem__(self, key):
try:
return OrderedDict.__getitem__(self, key)
except KeyError:
return self.__missing__(key)
def __missing__(self, key):
if self.default_factory is None:
raise KeyError(key)
self[key] = value = self.default_factory()
return value
def __reduce__(self):
if self.default_factory is None:
args = tuple()
else:
args = (self.default_factory,)
return type(self), args, None, None, list(self.items())
def copy(self):
return self.__copy__()
def __copy__(self):
return type(self)(self.default_factory, self)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论