提交 c8394fb3 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3147 from harlouci/props_compile

Props compile
......@@ -9,6 +9,7 @@ import warnings
import theano
from theano import gof
from theano.compat import OrderedDict
from six import iteritems
from six.moves import xrange
......@@ -38,16 +39,11 @@ class ViewOp(gof.Op):
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version = {}
__props__ = ()
def make_node(self, x):
return gof.Apply(self, [x], [x.type()])
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def perform(self, node, inp, out):
x, = inp
z, = out
......@@ -138,19 +134,11 @@ class DeepCopyOp(gof.Op):
c_code_and_version = {}
check_input = False
__props__ = ()
def __init__(self):
pass
def __str__(self):
return self.__class__.__name__
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
return type(self) == type(other)
def make_node(self, x):
return gof.Apply(self, [x], [x.type()])
......@@ -228,15 +216,7 @@ class Shape(gof.Op):
c_code_and_version = {}
check_input = False
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
return type(self) == type(other)
def __str__(self):
return self.__class__.__name__
__props__ = ()
def make_node(self, x):
# Must work for all type that have a shape attribute.
......@@ -480,6 +460,7 @@ class FromFunctionOp(gof.Op):
raise an error if you attempt to get the gradient of a graph
containing this op.
"""
def __init__(self, fn, itypes, otypes, infer_shape):
self.__fn = fn
self.itypes = itypes
......@@ -623,17 +604,21 @@ class Rebroadcast(gof.Op):
c_code_and_version = {}
check_input = False
__props__ = ("axis",)
def __init__(self, *axis):
self.axis = dict(axis)
# Sort them to make sure we merge all possible case.
items = sorted(axis)
self.axis = OrderedDict(items)
for axis, broad in iteritems(self.axis):
assert isinstance(axis, (numpy.integer, int)), (
"Rebroadcast needs integer axes. Got ", axis)
def __eq__(self, other):
return type(self) == type(other) and self.axis == other.axis
assert isinstance(broad, bool), (
"Rebroadcast needs bool for new broadcast pattern. Got ",
broad)
def __hash__(self):
# Need special __hash__ as dict aren't hashable.
# no ambiguity because each item key is unique
items = sorted(iteritems(self.axis))
return hash((type(self), tuple(items)))
......@@ -768,15 +753,7 @@ class SpecifyShape(gof.Op):
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version = {}
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
return type(self) == type(other)
def __str__(self):
return self.__class__.__name__
__props__ = ()
def make_node(self, x, shape):
if not isinstance(x, gof.Variable):
......
import theano
from theano.gof.utils import give_variables_names, unique, remove
from theano.compat import OrderedDict
from theano.gof.utils import (
give_variables_names, hash_from_dict, remove, unique)
def test_give_variables_names():
......@@ -44,3 +46,19 @@ def test_remove():
# The list are needed as with python 3, remove and filter return generators
# and we can't compare generators.
assert list(remove(even, range(5))) == list(filter(odd, range(5)))
def test_hash_from_dict():
dicts = [{}, {0: 0}, {0: 1}, {1: 0}, {1: 1},
{0: (0,)}, {0: [1]},
{0: (0, 1)}, {0: [1, 0]}]
for elem in dicts[:]:
dicts.append(OrderedDict(elem))
hashs = []
for idx, d in enumerate(dicts):
h = hash_from_dict(d)
assert h not in hashs
hashs.append(h)
# List are not hashable. So they are transformed into tuple.
assert hash_from_dict({0: (0,)}) == hash_from_dict({0: [0]})
......@@ -7,7 +7,7 @@ import numpy
from six import iteritems
from theano import config
from theano.compat import PY3
from theano.compat import OrderedDict, PY3
def simple_extract_stack(f=None, limit=None):
......@@ -465,3 +465,33 @@ else:
def hash_from_file(file_path):
"""Return the MD5 hash of a file."""
return hash_from_code(open(file_path, 'rb').read())
def hash_from_dict(d):
"""Work around the fact that dict are not hashable in python
This request that all object have a sorted order that depend only
on the key of the object. We support only integer/float/string keys.
Also, we transform values that are list into tuple as list are not
hashable.
:note: special case for OrderedDict, it use the order of the dict,
so the key don't need to be sortable.
"""
if isinstance(d, OrderedDict):
items = list(iteritems(d))
else:
items = list(d.items())
items.sort()
first_part = [k for k, v in items]
second_part = []
for k, v in items:
assert isinstance(k, (str, int, float))
if isinstance(v, (tuple, list)):
second_part += [tuple(v)]
else:
second_part += [v]
tuple_items = tuple(first_part + second_part + [d.__class__])
return hash(tuple_items)
......@@ -27,11 +27,11 @@ import theano
from six.moves import xrange
from theano.compat import izip
from theano.gof import Op, Apply, local_optimizer, EquilibriumDB
from theano.gof.utils import hash_from_dict
from theano.sandbox.cuda import GpuElemwise, CudaNdarrayType, GpuOp
from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable,
gpu_contiguous)
from theano.sandbox.cuda.opt import gpu_seqopt
from theano.tensor.utils import hash_from_dict
import pycuda
from pycuda.compiler import SourceModule
......
......@@ -13,9 +13,9 @@ from theano.gof import Apply, Op, OpenMPOp
from theano import scalar
from theano.scalar import get_scalar_type
from theano.printing import pprint
from theano.tensor.utils import hash_from_dict
from theano.gradient import DisconnectedType
from theano.gof.null_type import NullType
from theano.gof.utils import hash_from_dict
from theano.tensor import elemwise_cgen as cgen
config = theano.config
......
......@@ -3,8 +3,7 @@ import unittest
import numpy
import theano
from theano.tensor.utils import (hash_from_ndarray, hash_from_dict,
shape_of_variables)
from theano.tensor.utils import (hash_from_ndarray, shape_of_variables)
def test_hash_from_ndarray():
......@@ -37,21 +36,6 @@ def test_hash_from_ndarray():
assert hash_from_ndarray(rng[::-1]) == hash_from_ndarray(rng[::-1].copy())
def test_hash_from_dict():
dicts = [{}, {0: 0}, {0: 1}, {1: 0}, {1: 1},
{0: (0,)}, {0: [1]},
{0: (0, 1)}, {0: [1, 0]},
]
hashs = []
for idx, d in enumerate(dicts):
h = hash_from_dict(d)
assert h not in hashs
hashs.append(h)
# List are not hashable. So they are transformed into tuple.
assert hash_from_dict({0: (0,)}) == hash_from_dict({0: [0]})
class Tshape_of_variables(unittest.TestCase):
def test_simple(self):
x = theano.tensor.matrix('x')
......
......@@ -30,31 +30,6 @@ def hash_from_ndarray(data):
hash_from_code(str(data.dtype)))
def hash_from_dict(d):
"""Work around the fact that dict are not hashable in python
This request that all object have a sorted order that depend only
on the value of the object. This is true for integer/float/string
We do not verify that the objects in the dict have this property.
Also, we transform values that are list into tuple as list are not
hashable.
"""
items = list(d.items())
items.sort()
first_part = [k for k, v in items]
second_part = []
for k, v in items:
if isinstance(v, (tuple, list)):
second_part += [tuple(v)]
else:
second_part += [v]
tuple_items = tuple(first_part + second_part)
return hash(tuple_items)
def shape_of_variables(fgraph, input_shapes):
"""
Compute the numeric shape of all intermediate variables given input shapes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论