提交 7972a4ed authored 作者: Virgile Andreani's avatar Virgile Andreani

Remove frozendict

上级 bf8a1b5a
...@@ -18,8 +18,6 @@ theano/tensor/sharedvar.py: James Bergstra, (c) 2010, Universite de Montreal, 3- ...@@ -18,8 +18,6 @@ theano/tensor/sharedvar.py: James Bergstra, (c) 2010, Universite de Montreal, 3-
theano/gradient.py: James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow, PyMC Developers, PyTensor Developers, (c) 2011, Universite de Montreal, 3-clause BSD License theano/gradient.py: James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow, PyMC Developers, PyTensor Developers, (c) 2011, Universite de Montreal, 3-clause BSD License
theano/compile/monitormode.py: this code was initially copied from the 'pyutools' package by its original author, and re-licensed under Theano's license. theano/compile/monitormode.py: this code was initially copied from the 'pyutools' package by its original author, and re-licensed under Theano's license.
Contains frozendict code from slezica’s python-frozendict(https://github.com/slezica/python-frozendict/blob/master/frozendict/__init__.py), Copyright (c) 2012 Santiago Lezica. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met: modification, are permitted provided that the following conditions are met:
......
...@@ -488,7 +488,7 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -488,7 +488,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs]) input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs]) output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
output_dtypes = tuple(out.type.dtype for out in node.outputs) output_dtypes = tuple(out.type.dtype for out in node.outputs)
inplace_pattern = tuple(op.inplace_pattern.items()) inplace_pattern = op.inplace_pattern
core_output_shapes = tuple(() for _ in range(nout)) core_output_shapes = tuple(() for _ in range(nout))
# numba doesn't support nested literals right now... # numba doesn't support nested literals right now...
......
# License : https://github.com/slezica/python-frozendict/blob/master/LICENSE.txt
import functools
import operator
from collections.abc import Mapping
class frozendict(Mapping):
    """
    A read-only ``Mapping`` wrapper around a plain dictionary.

    Implements the complete :py:class:`collections.abc.Mapping` interface,
    so instances can stand in for regular dictionaries wherever immutability
    (and therefore hashability) is desired, while preserving insertion order.
    """

    # Underlying mapping type used to store the wrapped data; subclasses
    # may override this to wrap a different dict-like type.
    dict_cls = dict

    def __init__(self, *args, **kwargs):
        # Materialise the wrapped dict eagerly; the hash is computed lazily
        # on first use and cached (None means "not computed yet").
        self._dict = self.dict_cls(*args, **kwargs)
        self._hash = None

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def copy(self, **overrides):
        """Return a new instance with ``overrides`` merged on top of self."""
        return self.__class__(self, **overrides)

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

    def __repr__(self):
        return f"<{self.__class__.__name__} {self._dict!r}>"

    def __hash__(self):
        # XOR of item hashes is order-independent, so two frozendicts with
        # equal contents hash identically regardless of insertion order.
        # The result is cached after the first computation.
        if self._hash is None:
            acc = 0
            for item in self.items():
                acc ^= hash(item)
            self._hash = acc
        return self._hash
...@@ -13,7 +13,6 @@ from pytensor.graph.utils import MethodNotDefined ...@@ -13,7 +13,6 @@ from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.basic import failure_code from pytensor.link.c.basic import failure_code
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.misc.frozendict import frozendict
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, pprint from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type from pytensor.scalar import get_scalar_type
...@@ -374,11 +373,11 @@ class Elemwise(OpenMPOp): ...@@ -374,11 +373,11 @@ class Elemwise(OpenMPOp):
""" """
assert not isinstance(scalar_op, type(self)) assert not isinstance(scalar_op, type(self))
if inplace_pattern is None: if inplace_pattern is None:
inplace_pattern = frozendict({}) inplace_pattern = {}
self.name = name self.name = name
self.scalar_op = scalar_op self.scalar_op = scalar_op
self.inplace_pattern = inplace_pattern self.inplace_pattern = tuple(inplace_pattern.items())
self.destroy_map = {o: [i] for o, i in self.inplace_pattern.items()} self.destroy_map = {o: [i] for o, i in self.inplace_pattern}
if nfunc_spec is None: if nfunc_spec is None:
nfunc_spec = getattr(scalar_op, "nfunc_spec", None) nfunc_spec = getattr(scalar_op, "nfunc_spec", None)
...@@ -397,7 +396,6 @@ class Elemwise(OpenMPOp): ...@@ -397,7 +396,6 @@ class Elemwise(OpenMPOp):
super().__setstate__(d) super().__setstate__(d)
self.ufunc = None self.ufunc = None
self.nfunc = None self.nfunc = None
self.inplace_pattern = frozendict(self.inplace_pattern)
def get_output_info(self, dim_shuffle, *inputs): def get_output_info(self, dim_shuffle, *inputs):
"""Return the outputs dtype and broadcastable pattern and the """Return the outputs dtype and broadcastable pattern and the
...@@ -446,9 +444,7 @@ class Elemwise(OpenMPOp): ...@@ -446,9 +444,7 @@ class Elemwise(OpenMPOp):
) )
# inplace_pattern maps output idx -> input idx # inplace_pattern maps output idx -> input idx
inplace_pattern = self.inplace_pattern for overwriter, overwritten in self.inplace_pattern:
if inplace_pattern:
for overwriter, overwritten in inplace_pattern.items():
for out_s, in_s in zip( for out_s, in_s in zip(
out_shapes[overwriter], out_shapes[overwriter],
inputs[overwritten].type.shape, inputs[overwritten].type.shape,
...@@ -460,13 +456,11 @@ class Elemwise(OpenMPOp): ...@@ -460,13 +456,11 @@ class Elemwise(OpenMPOp):
) )
out_dtypes = [o.type.dtype for o in shadow.outputs] out_dtypes = [o.type.dtype for o in shadow.outputs]
if any( if any(inputs[i].type.dtype != out_dtypes[o] for o, i in self.inplace_pattern):
inputs[i].type.dtype != out_dtypes[o] for o, i in inplace_pattern.items()
):
raise TypeError( raise TypeError(
( (
"Cannot do an inplace operation on incompatible data types.", "Cannot do an inplace operation on incompatible data types.",
([i.type.dtype for i in inputs], out_dtypes, inplace_pattern), ([i.type.dtype for i in inputs], out_dtypes, self.inplace_pattern),
) )
) )
assert len(out_dtypes) == len(out_shapes) assert len(out_dtypes) == len(out_shapes)
...@@ -755,6 +749,7 @@ class Elemwise(OpenMPOp): ...@@ -755,6 +749,7 @@ class Elemwise(OpenMPOp):
if nout == 1: if nout == 1:
variables = [variables] variables = [variables]
inplace_pattern = dict(self.inplace_pattern)
for i, (variable, storage, nout) in enumerate( for i, (variable, storage, nout) in enumerate(
zip(variables, output_storage, node.outputs) zip(variables, output_storage, node.outputs)
): ):
...@@ -763,8 +758,8 @@ class Elemwise(OpenMPOp): ...@@ -763,8 +758,8 @@ class Elemwise(OpenMPOp):
# always return an ndarray with dtype object # always return an ndarray with dtype object
variable = np.asarray(variable, dtype=nout.dtype) variable = np.asarray(variable, dtype=nout.dtype)
if i in self.inplace_pattern: if i in inplace_pattern:
odat = inputs[self.inplace_pattern[i]] odat = inputs[inplace_pattern[i]]
odat[...] = variable odat[...] = variable
storage[0] = odat storage[0] = odat
...@@ -832,9 +827,7 @@ class Elemwise(OpenMPOp): ...@@ -832,9 +827,7 @@ class Elemwise(OpenMPOp):
# The destroy map is a map of output indices to input indices # The destroy map is a map of output indices to input indices
# that overwrite them. We just convert them to the actual # that overwrite them. We just convert them to the actual
# Variables. # Variables.
dmap = { dmap = {node.outputs[o]: [node.inputs[i]] for o, i in self.inplace_pattern}
node.outputs[o]: [node.inputs[i]] for o, i in self.inplace_pattern.items()
}
# dtypes of the inputs # dtypes of the inputs
idtypes = [input.type.dtype_specs()[1] for input in inputs] idtypes = [input.type.dtype_specs()[1] for input in inputs]
......
...@@ -173,7 +173,7 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -173,7 +173,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
# original node add already some inplace patter and we # original node add already some inplace patter and we
# still try to add more pattern. # still try to add more pattern.
baseline = op.inplace_pattern baseline = dict(op.inplace_pattern)
candidate_outputs = [ candidate_outputs = [
i for i in self.candidate_input_idxs(node) if i not in baseline i for i in self.candidate_input_idxs(node) if i not in baseline
] ]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论