提交 7972a4ed authored 作者: Virgile Andreani's avatar Virgile Andreani

Remove frozendict

上级 bf8a1b5a
......@@ -18,8 +18,6 @@ theano/tensor/sharedvar.py: James Bergstra, (c) 2010, Universite de Montreal, 3-
theano/gradient.py: James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow, PyMC Developers, PyTensor Developers, (c) 2011, Universite de Montreal, 3-clause BSD License
theano/compile/monitormode.py: this code was initially copied from the 'pyutools' package by its original author, and re-licensed under Theano's license.
Contains frozendict code from slezica’s python-frozendict(https://github.com/slezica/python-frozendict/blob/master/frozendict/__init__.py), Copyright (c) 2012 Santiago Lezica. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
......
......@@ -488,7 +488,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
output_dtypes = tuple(out.type.dtype for out in node.outputs)
inplace_pattern = tuple(op.inplace_pattern.items())
inplace_pattern = op.inplace_pattern
core_output_shapes = tuple(() for _ in range(nout))
# numba doesn't support nested literals right now...
......
# License : https://github.com/slezica/python-frozendict/blob/master/LICENSE.txt
import functools
import operator
from collections.abc import Mapping
class frozendict(Mapping):
"""
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping`
interface. It can be used as a drop-in replacement for dictionaries where immutability and ordering are desired.
"""
dict_cls = dict
def __init__(self, *args, **kwargs):
self._dict = self.dict_cls(*args, **kwargs)
self._hash = None
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def copy(self, **add_or_replace):
return self.__class__(self, **add_or_replace)
def __iter__(self):
return iter(self._dict)
def __len__(self):
return len(self._dict)
def __repr__(self):
return f"<{self.__class__.__name__} {self._dict!r}>"
def __hash__(self):
if self._hash is None:
hashes = map(hash, self.items())
self._hash = functools.reduce(operator.xor, hashes, 0)
return self._hash
......@@ -13,7 +13,6 @@ from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.basic import failure_code
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
from pytensor.link.c.params_type import ParamsType
from pytensor.misc.frozendict import frozendict
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type
......@@ -374,11 +373,11 @@ class Elemwise(OpenMPOp):
"""
assert not isinstance(scalar_op, type(self))
if inplace_pattern is None:
inplace_pattern = frozendict({})
inplace_pattern = {}
self.name = name
self.scalar_op = scalar_op
self.inplace_pattern = inplace_pattern
self.destroy_map = {o: [i] for o, i in self.inplace_pattern.items()}
self.inplace_pattern = tuple(inplace_pattern.items())
self.destroy_map = {o: [i] for o, i in self.inplace_pattern}
if nfunc_spec is None:
nfunc_spec = getattr(scalar_op, "nfunc_spec", None)
......@@ -397,7 +396,6 @@ class Elemwise(OpenMPOp):
super().__setstate__(d)
self.ufunc = None
self.nfunc = None
self.inplace_pattern = frozendict(self.inplace_pattern)
def get_output_info(self, dim_shuffle, *inputs):
"""Return the outputs dtype and broadcastable pattern and the
......@@ -446,27 +444,23 @@ class Elemwise(OpenMPOp):
)
# inplace_pattern maps output idx -> input idx
inplace_pattern = self.inplace_pattern
if inplace_pattern:
for overwriter, overwritten in inplace_pattern.items():
for out_s, in_s in zip(
out_shapes[overwriter],
inputs[overwritten].type.shape,
):
if in_s == 1 and out_s != 1:
raise ValueError(
"Operation cannot be done inplace on an input "
"with broadcasted dimensions."
)
for overwriter, overwritten in self.inplace_pattern:
for out_s, in_s in zip(
out_shapes[overwriter],
inputs[overwritten].type.shape,
):
if in_s == 1 and out_s != 1:
raise ValueError(
"Operation cannot be done inplace on an input "
"with broadcasted dimensions."
)
out_dtypes = [o.type.dtype for o in shadow.outputs]
if any(
inputs[i].type.dtype != out_dtypes[o] for o, i in inplace_pattern.items()
):
if any(inputs[i].type.dtype != out_dtypes[o] for o, i in self.inplace_pattern):
raise TypeError(
(
"Cannot do an inplace operation on incompatible data types.",
([i.type.dtype for i in inputs], out_dtypes, inplace_pattern),
([i.type.dtype for i in inputs], out_dtypes, self.inplace_pattern),
)
)
assert len(out_dtypes) == len(out_shapes)
......@@ -755,6 +749,7 @@ class Elemwise(OpenMPOp):
if nout == 1:
variables = [variables]
inplace_pattern = dict(self.inplace_pattern)
for i, (variable, storage, nout) in enumerate(
zip(variables, output_storage, node.outputs)
):
......@@ -763,8 +758,8 @@ class Elemwise(OpenMPOp):
# always return an ndarray with dtype object
variable = np.asarray(variable, dtype=nout.dtype)
if i in self.inplace_pattern:
odat = inputs[self.inplace_pattern[i]]
if i in inplace_pattern:
odat = inputs[inplace_pattern[i]]
odat[...] = variable
storage[0] = odat
......@@ -832,9 +827,7 @@ class Elemwise(OpenMPOp):
# The destroy map is a map of output indices to input indices
# that overwrite them. We just convert them to the actual
# Variables.
dmap = {
node.outputs[o]: [node.inputs[i]] for o, i in self.inplace_pattern.items()
}
dmap = {node.outputs[o]: [node.inputs[i]] for o, i in self.inplace_pattern}
# dtypes of the inputs
idtypes = [input.type.dtype_specs()[1] for input in inputs]
......
......@@ -173,7 +173,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
# original node add already some inplace patter and we
# still try to add more pattern.
baseline = op.inplace_pattern
baseline = dict(op.inplace_pattern)
candidate_outputs = [
i for i in self.candidate_input_idxs(node) if i not in baseline
]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论