提交 19e4cfb9 authored 作者: James Bergstra's avatar James Bergstra

added TensorConstantSignature so that merging works for non-scalar constants

上级 2a085abe
...@@ -178,29 +178,38 @@ class MergeOptimizer(Optimizer): ...@@ -178,29 +178,38 @@ class MergeOptimizer(Optimizer):
def add_requirements(self, env): def add_requirements(self, env):
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
def apply(self, env): def apply_constant_merge(self, env):
cid = _metadict() #result -> result.desc() (for constants) const_sig = _metadict() # result -> result.signature() (for constants)
inv_cid = _metadict() #desc -> result (for constants) const_sig_inv = _metadict() # signature -> result (for constants)
for i, r in enumerate([r for r in env.results if isinstance(r, graph.Constant)]): for i, c in enumerate([r for r in env.results if isinstance(r, graph.Constant)]):
sig = r.signature() sig = c.signature()
other_r = inv_cid.get(sig, None) other_c = const_sig_inv.get(sig, None)
if other_r is not None: if other_c is not None:
if r.name: other_r.name = r.name # multiple names will clobber each other..
env.replace_validate(r, other_r) # we adopt convention to keep the last name
if c.name:
other_c.name = c.name
env.replace_validate(c, other_c)
else: else:
cid[r] = sig #this is a new constant
inv_cid[sig] = r const_sig[c] = sig
const_sig_inv[sig] = c
def apply_node_merge(self, env):
# we clear the dicts because the Constants signatures are not necessarily hashable # we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer cid like the other Results # and it's more efficient to give them an integer like the other Results
cid.clear()
inv_cid.clear() symbol_idx = {} #result -> int
symbol_idx_inv = {} #int -> result (inverse of symbol_idx)
#add all graph sources to the symbol_idx dictionaries (arbitrary order)
for i, r in enumerate(r for r in env.results if r.owner is None): for i, r in enumerate(r for r in env.results if r.owner is None):
cid[r] = i symbol_idx[r] = i
inv_cid[i] = r symbol_idx_inv[i] = r
for node in _list_of_nodes(env): for node in _list_of_nodes(env):
node_cid = (node.op, tuple([cid[input] for input in node.inputs])) node_cid = (node.op, tuple([symbol_idx[input] for input in node.inputs]))
dup = inv_cid.get(node_cid, None) dup = symbol_idx_inv.get(node_cid, None)
success = False success = False
if dup is not None: if dup is not None:
success = True success = True
...@@ -213,12 +222,17 @@ class MergeOptimizer(Optimizer): ...@@ -213,12 +222,17 @@ class MergeOptimizer(Optimizer):
except InconsistencyError, e: except InconsistencyError, e:
success = False success = False
if not success: if not success:
cid[node] = node_cid symbol_idx[node] = node_cid
inv_cid[node_cid] = node symbol_idx_inv[node_cid] = node
for i, output in enumerate(node.outputs): for i, output in enumerate(node.outputs):
ref = (i, node_cid) ref = (i, node_cid)
cid[output] = ref symbol_idx[output] = ref
inv_cid[ref] = output symbol_idx_inv[ref] = output
#TODO: Consider splitting this into a separate optimizer (SeqOptimizer)
def apply(self, env):
    """Run both merge passes on `env`: constants first, then nodes."""
    for pass_name in ('apply_constant_merge', 'apply_node_merge'):
        # Look each pass up only when it is about to run, so the order of
        # attribute access and calls matches two plain method calls.
        getattr(self, pass_name)(env)
def MergeOptMerge(opt): def MergeOptMerge(opt):
......
...@@ -337,7 +337,6 @@ class TestMergeOptimizer: ...@@ -337,7 +337,6 @@ class TestMergeOptimizer:
assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]' assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]'
class TestEquilibrium(object): class TestEquilibrium(object):
def test_1(self): def test_1(self):
......
...@@ -105,7 +105,7 @@ def as_tensor(x, name = None): ...@@ -105,7 +105,7 @@ def as_tensor(x, name = None):
_as_tensor = as_tensor _as_tensor = as_tensor
def constant(x): def constant(x, name=None):
"""Return a symbolic `Constant` with value `x` """Return a symbolic `Constant` with value `x`
:Exceptions: :Exceptions:
...@@ -117,7 +117,7 @@ def constant(x): ...@@ -117,7 +117,7 @@ def constant(x):
x_ = numpy.asarray(x) x_ = numpy.asarray(x)
try: try:
return TensorConstant(Tensor(dtype = x_.dtype, return TensorConstant(Tensor(dtype = x_.dtype,
broadcastable = [d == 1 for d in x_.shape]), x_) broadcastable = [d == 1 for d in x_.shape]), x_, name=name)
except: except:
raise TypeError("Could not convert %s to Tensor" % x, type(x)) raise TypeError("Could not convert %s to Tensor" % x, type(x))
...@@ -554,11 +554,22 @@ class _tensor_py_operators: ...@@ -554,11 +554,22 @@ class _tensor_py_operators:
class TensorResult(Result, _tensor_py_operators): class TensorResult(Result, _tensor_py_operators):
"""Subclass to add the tensor operators to the basic `Result` class.""" """Subclass to add the tensor operators to the basic `Result` class."""
class TensorConstantSignature(tuple):
    """A `(type, data)` pair that compares and hashes by array content.

    Two signatures are equal when their types are equal, their data have
    exactly the same shape, and every element of the data matches.  A plain
    tuple cannot be used for this because comparing two ndarrays with `==`
    yields an array, not a single truth value.
    """

    def __eq__(self, other):
        """Return True iff `other` is a `(type, data)` pair matching this one.

        Objects that cannot be unpacked into two items, or whose second item
        has no `shape`, simply compare unequal instead of raising.
        """
        try:
            x, y = other
        except (TypeError, ValueError):
            # not iterable, or not exactly two items
            return False
        a, b = self
        if not (x == a):
            return False
        # N.B. compare shape to ensure no broadcasting in ==
        if getattr(y, 'shape', None) != b.shape:
            return False
        return bool(numpy.all(b == y))

    def __ne__(self, other):
        """Return the negation of `__eq__`.

        Needed explicitly: otherwise `!=` falls back to tuple.__ne__, which
        compares the arrays element-wise and fails on the ambiguous result.
        """
        return not self.__eq__(other)

    def __hash__(self):
        """Hash on the signature class, the type and the data's shape.

        The data values are left out of the hash; equal signatures always
        share type and shape, so equal objects still hash equal.
        """
        a, b = self
        return hash(type(self)) ^ hash(a) ^ hash(b.shape)
class TensorConstant(Constant, _tensor_py_operators): class TensorConstant(Constant, _tensor_py_operators):
"""Subclass to add the tensor operators to the basic `Constant` class. """Subclass to add the tensor operators to the basic `Constant` class.
To create a TensorConstant, use the `constant` function in this module. To create a TensorConstant, use the `constant` function in this module.
""" """
def signature(self):
    """Return a TensorConstantSignature built from this constant's type and data."""
    type_and_data = (self.type, self.data)
    return TensorConstantSignature(type_and_data)
class TensorValue(Value, _tensor_py_operators): class TensorValue(Value, _tensor_py_operators):
"""Subclass to add the tensor operators to the basic `Value` class. """Subclass to add the tensor operators to the basic `Value` class.
......
import numpy
from theano.gof.type import Type
from theano.gof.graph import Result, Apply, Constant
from theano.gof.op import Op
from theano.gof.opt import *
from theano.gof.env import Env
from theano.gof.toolbox import *
import theano.tensor.basic as T
def as_result(x):
    """Return `x` unchanged when it is a Result.

    Raises TypeError (with `x` as second argument) for anything else.
    """
    if isinstance(x, Result):
        return x
    raise TypeError("not a Result", x)
class MyType(Type):
    """Trivial Type for these tests: every MyType instance is interchangeable."""

    def filter(self, data):
        """Accept `data` as-is, without any validation or conversion."""
        return data

    def __eq__(self, other):
        """Any two MyType instances compare equal."""
        return isinstance(other, MyType)

    def __ne__(self, other):
        """Negation of `__eq__` (Python 2 does not derive it automatically)."""
        return not self.__eq__(other)

    def __hash__(self):
        """Constant hash, so that instances which compare equal also hash equal."""
        return hash(MyType)
class MyOp(Op):
    """Minimal Op for these tests, taking any number of MyType inputs.

    :param name: text returned by `str()` and `repr()`.
    :param dmap: stored as `destroy_map`; a fresh empty dict when omitted.
    :param x: optional equality key.  Two MyOps with the same non-None `x`
        compare (and hash) equal; with `x` left as None an op is only equal
        to itself.
    """

    def __init__(self, name, dmap=None, x=None):
        self.name = name
        # Build the default per instance: a `{}` default argument would be one
        # dict shared (and mutated) across every MyOp created without `dmap`.
        self.destroy_map = {} if dmap is None else dmap
        self.x = x

    def make_node(self, *inputs):
        """Return an Apply of this op on `inputs` with one new MyType output.

        Raises TypeError (from `as_result`) if an input is not a Result, and
        Exception("Error 1") if an input's type is not a MyType.
        """
        # A real list, not map(): it is iterated here and then handed to Apply.
        inputs = [as_result(inp) for inp in inputs]
        for inp in inputs:
            if not isinstance(inp.type, MyType):
                raise Exception("Error 1")
        outputs = [MyType()()]
        return Apply(self, inputs, outputs)

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    def __eq__(self, other):
        """Identity, or both are MyOps sharing the same non-None `x`."""
        return self is other or (isinstance(other, MyOp)
                                 and self.x is not None
                                 and self.x == other.x)

    def __ne__(self, other):
        """Negation of `__eq__` (Python 2 does not derive it automatically)."""
        return not self.__eq__(other)

    def __hash__(self):
        """Hash of `x` when set, otherwise the identity of this op.

        Uses `hash(self.x)` rather than `self.x` itself so that a non-integer
        key (e.g. a string) does not make hashing raise TypeError.
        """
        return hash(self.x) if self.x is not None else id(self)
# Module-level MyOp displayed as 'Op1'; its `x` is left as None, so it
# compares equal only to itself.
op1 = MyOp('Op1')
def test_merge_with_weird_eq():
    """Equal-valued tensor constants must be merged into a single input.

    numpy arrays do not compare equal like other python objects, so both a
    scalar and a non-scalar constant are checked.
    """
    value_builders = (
        lambda: numpy.asarray(1),  # scalar case
        lambda: numpy.ones(5),     # non-scalar case (exercises TensorConstantSignature)
    )
    for build_value in value_builders:
        # x and y wrap two distinct arrays holding the same values
        x = T.constant(build_value(), name='x')
        y = T.constant(build_value(), name='y')
        g = Env([x, y], [x + y])

        MergeOptimizer().optimize(g)

        assert len(g.nodes) == 1
        node = list(g.nodes)[0]
        assert len(node.inputs) == 2
        # after merging, both inputs of the addition are the same object
        assert node.inputs[0] is node.inputs[1]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论