提交 c8394fb3 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3147 from harlouci/props_compile

Props compile
...@@ -9,6 +9,7 @@ import warnings ...@@ -9,6 +9,7 @@ import warnings
import theano import theano
from theano import gof from theano import gof
from theano.compat import OrderedDict
from six import iteritems from six import iteritems
from six.moves import xrange from six.moves import xrange
...@@ -38,16 +39,11 @@ class ViewOp(gof.Op): ...@@ -38,16 +39,11 @@ class ViewOp(gof.Op):
# In the C code, the name of the input variable is %(iname)s, # In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version = {}
__props__ = ()
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, = inp x, = inp
z, = out z, = out
...@@ -138,19 +134,11 @@ class DeepCopyOp(gof.Op): ...@@ -138,19 +134,11 @@ class DeepCopyOp(gof.Op):
c_code_and_version = {} c_code_and_version = {}
check_input = False check_input = False
__props__ = ()
def __init__(self): def __init__(self):
pass pass
def __str__(self):
return self.__class__.__name__
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
return type(self) == type(other)
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
...@@ -228,15 +216,7 @@ class Shape(gof.Op): ...@@ -228,15 +216,7 @@ class Shape(gof.Op):
c_code_and_version = {} c_code_and_version = {}
check_input = False check_input = False
__props__ = ()
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
return type(self) == type(other)
def __str__(self):
return self.__class__.__name__
def make_node(self, x): def make_node(self, x):
# Must work for all type that have a shape attribute. # Must work for all type that have a shape attribute.
...@@ -480,6 +460,7 @@ class FromFunctionOp(gof.Op): ...@@ -480,6 +460,7 @@ class FromFunctionOp(gof.Op):
raise an error if you attempt to get the gradient of a graph raise an error if you attempt to get the gradient of a graph
containing this op. containing this op.
""" """
def __init__(self, fn, itypes, otypes, infer_shape): def __init__(self, fn, itypes, otypes, infer_shape):
self.__fn = fn self.__fn = fn
self.itypes = itypes self.itypes = itypes
...@@ -623,17 +604,21 @@ class Rebroadcast(gof.Op): ...@@ -623,17 +604,21 @@ class Rebroadcast(gof.Op):
c_code_and_version = {} c_code_and_version = {}
check_input = False check_input = False
__props__ = ("axis",)
def __init__(self, *axis): def __init__(self, *axis):
self.axis = dict(axis) # Sort them to make sure we merge all possible case.
items = sorted(axis)
self.axis = OrderedDict(items)
for axis, broad in iteritems(self.axis): for axis, broad in iteritems(self.axis):
assert isinstance(axis, (numpy.integer, int)), ( assert isinstance(axis, (numpy.integer, int)), (
"Rebroadcast needs integer axes. Got ", axis) "Rebroadcast needs integer axes. Got ", axis)
assert isinstance(broad, bool), (
def __eq__(self, other): "Rebroadcast needs bool for new broadcast pattern. Got ",
return type(self) == type(other) and self.axis == other.axis broad)
def __hash__(self): def __hash__(self):
# Need special __hash__ as dict aren't hashable.
# no ambiguity because each item key is unique # no ambiguity because each item key is unique
items = sorted(iteritems(self.axis)) items = sorted(iteritems(self.axis))
return hash((type(self), tuple(items))) return hash((type(self), tuple(items)))
...@@ -768,15 +753,7 @@ class SpecifyShape(gof.Op): ...@@ -768,15 +753,7 @@ class SpecifyShape(gof.Op):
# In the C code, the name of the input variable is %(iname)s, # In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version = {}
__props__ = ()
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
return type(self) == type(other)
def __str__(self):
return self.__class__.__name__
def make_node(self, x, shape): def make_node(self, x, shape):
if not isinstance(x, gof.Variable): if not isinstance(x, gof.Variable):
......
import theano import theano
from theano.gof.utils import give_variables_names, unique, remove from theano.compat import OrderedDict
from theano.gof.utils import (
give_variables_names, hash_from_dict, remove, unique)
def test_give_variables_names(): def test_give_variables_names():
...@@ -44,3 +46,19 @@ def test_remove(): ...@@ -44,3 +46,19 @@ def test_remove():
# The list are needed as with python 3, remove and filter return generators # The list are needed as with python 3, remove and filter return generators
# and we can't compare generators. # and we can't compare generators.
assert list(remove(even, range(5))) == list(filter(odd, range(5))) assert list(remove(even, range(5))) == list(filter(odd, range(5)))
def test_hash_from_dict():
dicts = [{}, {0: 0}, {0: 1}, {1: 0}, {1: 1},
{0: (0,)}, {0: [1]},
{0: (0, 1)}, {0: [1, 0]}]
for elem in dicts[:]:
dicts.append(OrderedDict(elem))
hashs = []
for idx, d in enumerate(dicts):
h = hash_from_dict(d)
assert h not in hashs
hashs.append(h)
# List are not hashable. So they are transformed into tuple.
assert hash_from_dict({0: (0,)}) == hash_from_dict({0: [0]})
...@@ -7,7 +7,7 @@ import numpy ...@@ -7,7 +7,7 @@ import numpy
from six import iteritems from six import iteritems
from theano import config from theano import config
from theano.compat import PY3 from theano.compat import OrderedDict, PY3
def simple_extract_stack(f=None, limit=None): def simple_extract_stack(f=None, limit=None):
...@@ -465,3 +465,33 @@ else: ...@@ -465,3 +465,33 @@ else:
def hash_from_file(file_path): def hash_from_file(file_path):
"""Return the MD5 hash of a file.""" """Return the MD5 hash of a file."""
return hash_from_code(open(file_path, 'rb').read()) return hash_from_code(open(file_path, 'rb').read())
def hash_from_dict(d):
"""Work around the fact that dict are not hashable in python
This request that all object have a sorted order that depend only
on the key of the object. We support only integer/float/string keys.
Also, we transform values that are list into tuple as list are not
hashable.
:note: special case for OrderedDict, it use the order of the dict,
so the key don't need to be sortable.
"""
if isinstance(d, OrderedDict):
items = list(iteritems(d))
else:
items = list(d.items())
items.sort()
first_part = [k for k, v in items]
second_part = []
for k, v in items:
assert isinstance(k, (str, int, float))
if isinstance(v, (tuple, list)):
second_part += [tuple(v)]
else:
second_part += [v]
tuple_items = tuple(first_part + second_part + [d.__class__])
return hash(tuple_items)
...@@ -27,11 +27,11 @@ import theano ...@@ -27,11 +27,11 @@ import theano
from six.moves import xrange from six.moves import xrange
from theano.compat import izip from theano.compat import izip
from theano.gof import Op, Apply, local_optimizer, EquilibriumDB from theano.gof import Op, Apply, local_optimizer, EquilibriumDB
from theano.gof.utils import hash_from_dict
from theano.sandbox.cuda import GpuElemwise, CudaNdarrayType, GpuOp from theano.sandbox.cuda import GpuElemwise, CudaNdarrayType, GpuOp
from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable, from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable,
gpu_contiguous) gpu_contiguous)
from theano.sandbox.cuda.opt import gpu_seqopt from theano.sandbox.cuda.opt import gpu_seqopt
from theano.tensor.utils import hash_from_dict
import pycuda import pycuda
from pycuda.compiler import SourceModule from pycuda.compiler import SourceModule
......
...@@ -13,9 +13,9 @@ from theano.gof import Apply, Op, OpenMPOp ...@@ -13,9 +13,9 @@ from theano.gof import Apply, Op, OpenMPOp
from theano import scalar from theano import scalar
from theano.scalar import get_scalar_type from theano.scalar import get_scalar_type
from theano.printing import pprint from theano.printing import pprint
from theano.tensor.utils import hash_from_dict
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.gof.null_type import NullType from theano.gof.null_type import NullType
from theano.gof.utils import hash_from_dict
from theano.tensor import elemwise_cgen as cgen from theano.tensor import elemwise_cgen as cgen
config = theano.config config = theano.config
......
...@@ -3,8 +3,7 @@ import unittest ...@@ -3,8 +3,7 @@ import unittest
import numpy import numpy
import theano import theano
from theano.tensor.utils import (hash_from_ndarray, hash_from_dict, from theano.tensor.utils import (hash_from_ndarray, shape_of_variables)
shape_of_variables)
def test_hash_from_ndarray(): def test_hash_from_ndarray():
...@@ -37,21 +36,6 @@ def test_hash_from_ndarray(): ...@@ -37,21 +36,6 @@ def test_hash_from_ndarray():
assert hash_from_ndarray(rng[::-1]) == hash_from_ndarray(rng[::-1].copy()) assert hash_from_ndarray(rng[::-1]) == hash_from_ndarray(rng[::-1].copy())
def test_hash_from_dict():
dicts = [{}, {0: 0}, {0: 1}, {1: 0}, {1: 1},
{0: (0,)}, {0: [1]},
{0: (0, 1)}, {0: [1, 0]},
]
hashs = []
for idx, d in enumerate(dicts):
h = hash_from_dict(d)
assert h not in hashs
hashs.append(h)
# List are not hashable. So they are transformed into tuple.
assert hash_from_dict({0: (0,)}) == hash_from_dict({0: [0]})
class Tshape_of_variables(unittest.TestCase): class Tshape_of_variables(unittest.TestCase):
def test_simple(self): def test_simple(self):
x = theano.tensor.matrix('x') x = theano.tensor.matrix('x')
......
...@@ -30,31 +30,6 @@ def hash_from_ndarray(data): ...@@ -30,31 +30,6 @@ def hash_from_ndarray(data):
hash_from_code(str(data.dtype))) hash_from_code(str(data.dtype)))
def hash_from_dict(d):
"""Work around the fact that dict are not hashable in python
This request that all object have a sorted order that depend only
on the value of the object. This is true for integer/float/string
We do not verify that the objects in the dict have this property.
Also, we transform values that are list into tuple as list are not
hashable.
"""
items = list(d.items())
items.sort()
first_part = [k for k, v in items]
second_part = []
for k, v in items:
if isinstance(v, (tuple, list)):
second_part += [tuple(v)]
else:
second_part += [v]
tuple_items = tuple(first_part + second_part)
return hash(tuple_items)
def shape_of_variables(fgraph, input_shapes): def shape_of_variables(fgraph, input_shapes):
""" """
Compute the numeric shape of all intermediate variables given input shapes Compute the numeric shape of all intermediate variables given input shapes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论