提交 706ef19b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename object2 to MetaObject and MetaObject to MetaType

上级 604d622f
...@@ -43,4 +43,4 @@ from theano.gof.toolbox import ( ...@@ -43,4 +43,4 @@ from theano.gof.toolbox import (
Validator, Validator,
) )
from theano.gof.type import CEnumType, EnumList, EnumType, Generic, Type, generic from theano.gof.type import CEnumType, EnumList, EnumType, Generic, Type, generic
from theano.gof.utils import MethodNotDefined, object2 from theano.gof.utils import MetaObject, MethodNotDefined
...@@ -41,7 +41,7 @@ class MissingInputError(Exception): ...@@ -41,7 +41,7 @@ class MissingInputError(Exception):
Exception.__init__(self, s) Exception.__init__(self, s)
class FunctionGraph(utils.object2): class FunctionGraph(utils.MetaObject):
""" """
A `FunctionGraph` represents a subgraph bound by a set of input variables and A `FunctionGraph` represents a subgraph bound by a set of input variables and
a set of output variables, ie a subgraph that specifies a theano function. a set of output variables, ie a subgraph that specifies a theano function.
......
...@@ -12,13 +12,13 @@ import numpy as np ...@@ -12,13 +12,13 @@ import numpy as np
import theano import theano
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.utils import ( from theano.gof.utils import (
MetaObject,
MethodNotDefined, MethodNotDefined,
Scratchpad, Scratchpad,
TestValueError, TestValueError,
ValidatingScratchpad, ValidatingScratchpad,
add_tag_trace, add_tag_trace,
get_variable_trace_string, get_variable_trace_string,
object2,
) )
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
...@@ -28,7 +28,7 @@ __docformat__ = "restructuredtext en" ...@@ -28,7 +28,7 @@ __docformat__ = "restructuredtext en"
NoParams = object() NoParams = object()
class Node(object2): class Node(MetaObject):
"""A `Node` in a Theano graph. """A `Node` in a Theano graph.
Currently, graphs contain two kinds of `Nodes`: `Variable`s and `Apply`s. Currently, graphs contain two kinds of `Nodes`: `Variable`s and `Apply`s.
......
...@@ -21,11 +21,11 @@ from theano.gof import graph ...@@ -21,11 +21,11 @@ from theano.gof import graph
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Variable from theano.gof.graph import Apply, Variable
from theano.gof.utils import ( from theano.gof.utils import (
MetaObject,
MethodNotDefined, MethodNotDefined,
TestValueError, TestValueError,
add_tag_trace, add_tag_trace,
get_variable_trace_string, get_variable_trace_string,
object2,
) )
from theano.link.c.interface import CLinkerOp from theano.link.c.interface import CLinkerOp
...@@ -118,7 +118,7 @@ def compute_test_value(node): ...@@ -118,7 +118,7 @@ def compute_test_value(node):
output.tag.test_value = storage_map[output][0] output.tag.test_value = storage_map[output][0]
class Op(object2): class Op(MetaObject):
"""A class that models and constructs operations in a graph. """A class that models and constructs operations in a graph.
A `Op` instance has several responsibilities: A `Op` instance has several responsibilities:
......
...@@ -14,7 +14,7 @@ import theano ...@@ -14,7 +14,7 @@ import theano
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof import graph, utils from theano.gof import graph, utils
from theano.gof.op import COp from theano.gof.op import COp
from theano.gof.utils import MethodNotDefined, object2 from theano.gof.utils import MetaObject, MethodNotDefined
from theano.link.c.interface import CLinkerType from theano.link.c.interface import CLinkerType
...@@ -221,7 +221,7 @@ _nothing = """ ...@@ -221,7 +221,7 @@ _nothing = """
""" """
class Type(object2, PureType, CLinkerType): class Type(MetaObject, PureType, CLinkerType):
""" """
Convenience wrapper combining `PureType` and `CLinkerType`. Convenience wrapper combining `PureType` and `CLinkerType`.
......
...@@ -157,7 +157,7 @@ class MethodNotDefined(Exception): ...@@ -157,7 +157,7 @@ class MethodNotDefined(Exception):
""" """
class MetaObject(type): class MetaType(type):
def __new__(cls, name, bases, dct): def __new__(cls, name, bases, dct):
props = dct.get("__props__", None) props = dct.get("__props__", None)
if props is not None: if props is not None:
...@@ -223,7 +223,7 @@ class MetaObject(type): ...@@ -223,7 +223,7 @@ class MetaObject(type):
return type.__new__(cls, name, bases, dct) return type.__new__(cls, name, bases, dct)
class object2(metaclass=MetaObject): class MetaObject(metaclass=MetaType):
__slots__ = [] __slots__ = []
def __ne__(self, other): def __ne__(self, other):
......
...@@ -22,11 +22,17 @@ import numpy as np ...@@ -22,11 +22,17 @@ import numpy as np
import theano import theano
from theano import gof, printing from theano import gof, printing
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof import utils
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Constant, Variable from theano.gof.graph import Apply, Constant, Variable
from theano.gof.op import COp from theano.gof.op import COp
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.utils import (
MetaObject,
MethodNotDefined,
difference,
from_return_values,
to_return_values,
)
from theano.gradient import DisconnectedType, grad_undefined from theano.gradient import DisconnectedType, grad_undefined
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
from theano.printing import pprint from theano.printing import pprint
...@@ -954,7 +960,7 @@ def same_out_float_only(type): ...@@ -954,7 +960,7 @@ def same_out_float_only(type):
return (type,) return (type,)
class transfer_type(gof.utils.object2): class transfer_type(MetaObject):
__props__ = ("transfer",) __props__ = ("transfer",)
def __init__(self, *transfer): def __init__(self, *transfer):
...@@ -978,7 +984,7 @@ class transfer_type(gof.utils.object2): ...@@ -978,7 +984,7 @@ class transfer_type(gof.utils.object2):
# return [upcast if i is None else types[i] for i in self.transfer] # return [upcast if i is None else types[i] for i in self.transfer]
class specific_out(gof.utils.object2): class specific_out(MetaObject):
__props__ = ("spec",) __props__ = ("spec",)
def __init__(self, *spec): def __init__(self, *spec):
...@@ -1027,7 +1033,7 @@ def float_out_nocomplex(*types): ...@@ -1027,7 +1033,7 @@ def float_out_nocomplex(*types):
return (float64,) return (float64,)
class unary_out_lookup(gof.utils.object2): class unary_out_lookup(MetaObject):
""" """
Get a output_types_preference object by passing a dictionary: Get a output_types_preference object by passing a dictionary:
...@@ -1126,16 +1132,16 @@ class ScalarOp(COp): ...@@ -1126,16 +1132,16 @@ class ScalarOp(COp):
if self.nout == 1: if self.nout == 1:
output_storage[0][0] = self.impl(*inputs) output_storage[0][0] = self.impl(*inputs)
else: else:
variables = utils.from_return_values(self.impl(*inputs)) variables = from_return_values(self.impl(*inputs))
assert len(variables) == len(output_storage) assert len(variables) == len(output_storage)
for storage, variable in zip(output_storage, variables): for storage, variable in zip(output_storage, variables):
storage[0] = variable storage[0] = variable
def impl(self, *inputs): def impl(self, *inputs):
raise utils.MethodNotDefined("impl", type(self), self.__class__.__name__) raise MethodNotDefined("impl", type(self), self.__class__.__name__)
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
raise utils.MethodNotDefined("grad", type(self), self.__class__.__name__) raise MethodNotDefined("grad", type(self), self.__class__.__name__)
def L_op(self, inputs, outputs, output_gradients): def L_op(self, inputs, outputs, output_gradients):
return self.grad(inputs, output_gradients) return self.grad(inputs, output_gradients)
...@@ -1183,7 +1189,7 @@ class ScalarOp(COp): ...@@ -1183,7 +1189,7 @@ class ScalarOp(COp):
the inputs/outputs types) the inputs/outputs types)
""" """
raise theano.gof.utils.MethodNotDefined() raise MethodNotDefined()
def supports_c_code(self, inputs, outputs): def supports_c_code(self, inputs, outputs):
"""Returns True if the current op has functioning C code for """Returns True if the current op has functioning C code for
...@@ -1215,7 +1221,7 @@ class ScalarOp(COp): ...@@ -1215,7 +1221,7 @@ class ScalarOp(COp):
["z" for z in outputs], ["z" for z in outputs],
{"fail": "%(fail)s"}, {"fail": "%(fail)s"},
) )
except (theano.gof.utils.MethodNotDefined, NotImplementedError): except (MethodNotDefined, NotImplementedError):
return False return False
return True return True
...@@ -1235,7 +1241,7 @@ class UnaryScalarOp(ScalarOp): ...@@ -1235,7 +1241,7 @@ class UnaryScalarOp(ScalarOp):
# as this function do not broadcast # as this function do not broadcast
node.inputs[0].type != node.outputs[0].type node.inputs[0].type != node.outputs[0].type
): ):
raise theano.gof.utils.MethodNotDefined() raise MethodNotDefined()
dtype = node.inputs[0].type.dtype_specs()[1] dtype = node.inputs[0].type.dtype_specs()[1]
fct_call = self.c_code_contiguous_raw(dtype, "n", "x", "z") fct_call = self.c_code_contiguous_raw(dtype, "n", "x", "z")
...@@ -1250,7 +1256,7 @@ class UnaryScalarOp(ScalarOp): ...@@ -1250,7 +1256,7 @@ class UnaryScalarOp(ScalarOp):
def c_code_contiguous_raw(self, dtype, n, i, o): def c_code_contiguous_raw(self, dtype, n, i, o):
if not config.lib__amblibm: if not config.lib__amblibm:
raise theano.gof.utils.MethodNotDefined() raise MethodNotDefined()
if dtype.startswith("npy_"): if dtype.startswith("npy_"):
dtype = dtype[4:] dtype = dtype[4:]
if dtype == "float32" and self.amd_float32 is not None: if dtype == "float32" and self.amd_float32 is not None:
...@@ -1260,7 +1266,7 @@ class UnaryScalarOp(ScalarOp): ...@@ -1260,7 +1266,7 @@ class UnaryScalarOp(ScalarOp):
dtype = "double" dtype = "double"
fct = self.amd_float64 fct = self.amd_float64
else: else:
raise theano.gof.utils.MethodNotDefined() raise MethodNotDefined()
return f"{fct}({n}, {i}, {o})" return f"{fct}({n}, {i}, {o})"
...@@ -1895,7 +1901,7 @@ class Mul(ScalarOp): ...@@ -1895,7 +1901,7 @@ class Mul(ScalarOp):
if gz.type in complex_types: if gz.type in complex_types:
# zr+zi = (xr + xi)(yr + yi) # zr+zi = (xr + xi)(yr + yi)
# zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr ) # zr+zi = (xr*yr - xi*yi) + (xr yi + xi yr )
otherprod = mul(*(utils.difference(inputs, [input]))) otherprod = mul(*(difference(inputs, [input])))
yr = real(otherprod) yr = real(otherprod)
yi = imag(otherprod) yi = imag(otherprod)
if input.type in complex_types: if input.type in complex_types:
...@@ -1907,7 +1913,7 @@ class Mul(ScalarOp): ...@@ -1907,7 +1913,7 @@ class Mul(ScalarOp):
else: else:
retval += [yr * real(gz) + yi * imag(gz)] retval += [yr * real(gz) + yi * imag(gz)]
else: else:
retval += [mul(*([gz] + utils.difference(inputs, [input])))] retval += [mul(*([gz] + difference(inputs, [input])))]
return retval return retval
...@@ -2269,7 +2275,7 @@ class Pow(BinaryScalarOp): ...@@ -2269,7 +2275,7 @@ class Pow(BinaryScalarOp):
(x, y) = inputs (x, y) = inputs
(z,) = outputs (z,) = outputs
if not config.lib__amblibm: if not config.lib__amblibm:
raise theano.gof.utils.MethodNotDefined() raise MethodNotDefined()
# We compare the dtype AND the broadcast flag # We compare the dtype AND the broadcast flag
# as this function do not broadcast # as this function do not broadcast
...@@ -2310,7 +2316,7 @@ class Pow(BinaryScalarOp): ...@@ -2310,7 +2316,7 @@ class Pow(BinaryScalarOp):
{fct}(n, x, *y, z); {fct}(n, x, *y, z);
""" """
raise theano.gof.utils.MethodNotDefined() raise MethodNotDefined()
pow = Pow(upcast_out_min8, name="pow") pow = Pow(upcast_out_min8, name="pow")
...@@ -4224,7 +4230,7 @@ class Composite(ScalarOp): ...@@ -4224,7 +4230,7 @@ class Composite(ScalarOp):
def impl(self, *inputs): def impl(self, *inputs):
output_storage = [[None] for i in range(self.nout)] output_storage = [[None] for i in range(self.nout)]
self.perform(None, inputs, output_storage) self.perform(None, inputs, output_storage)
ret = utils.to_return_values([storage[0] for storage in output_storage]) ret = to_return_values([storage[0] for storage in output_storage])
if self.nout > 1: if self.nout > 1:
ret = tuple(ret) ret = tuple(ret)
return ret return ret
...@@ -4265,7 +4271,7 @@ class Composite(ScalarOp): ...@@ -4265,7 +4271,7 @@ class Composite(ScalarOp):
for subnode in self.fgraph.toposort(): for subnode in self.fgraph.toposort():
try: try:
rval.append(subnode.op.c_support_code().strip()) rval.append(subnode.op.c_support_code().strip())
except gof.utils.MethodNotDefined: except MethodNotDefined:
pass pass
# remove duplicate code blocks # remove duplicate code blocks
return "\n".join(sorted(set(rval))) return "\n".join(sorted(set(rval)))
...@@ -4280,7 +4286,7 @@ class Composite(ScalarOp): ...@@ -4280,7 +4286,7 @@ class Composite(ScalarOp):
) )
if subnode_support_code: if subnode_support_code:
rval.append(subnode_support_code) rval.append(subnode_support_code)
except gof.utils.MethodNotDefined: except MethodNotDefined:
pass pass
# there should be no need to remove duplicate code blocks because # there should be no need to remove duplicate code blocks because
# each block should have been specialized for the given nodename. # each block should have been specialized for the given nodename.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论