提交 c809a6a0 作者: Razvan Pascanu

PEP8 fixes to the file

上级 f585c34d
...@@ -20,32 +20,42 @@ except ImportError: ...@@ -20,32 +20,42 @@ except ImportError:
# some ops (e.g. Cholesky, Solve, A_Xinv_b) won't work # some ops (e.g. Cholesky, Solve, A_Xinv_b) won't work
imported_scipy = False imported_scipy = False
class Hint(Op): class Hint(Op):
""" """
Provide arbitrary information to the optimizer Provide arbitrary information to the optimizer
These ops are removed from the graph during canonicalization These ops are removed from the graph during canonicalization
in order to not interfere with other optimizations. in order to not interfere with other optimizations.
The idea is that prior to canonicalization, one or more Features of the env should The idea is that prior to canonicalization, one or more Features of the
register the information contained in any Hint node, and transfer that information out of env should register the information contained in any Hint node, and
the graph. transfer that information out of the graph.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.hints = tuple(kwargs.items()) self.hints = tuple(kwargs.items())
self.view_map = {0:[0]} self.view_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.hints == other.hints return type(self) == type(other) and self.hints == other.hints
def __hash__(self): def __hash__(self):
return hash((type(self), self.hints)) return hash((type(self), self.hints))
def make_node(self, x): def make_node(self, x):
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
def perform(self, node, inputs, outstor): def perform(self, node, inputs, outstor):
outstor[0][0] = inputs[0] outstor[0][0] = inputs[0]
def grad(self, inputs, g_out): def grad(self, inputs, g_out):
return g_out return g_out
def is_hint_node(node): def is_hint_node(node):
return isinstance(node.op, Hint) return isinstance(node.op, Hint)
def hints(variable): def hints(variable):
if hasattr(variable, 'env'): if hasattr(variable, 'env'):
try: try:
...@@ -58,13 +68,14 @@ def hints(variable): ...@@ -58,13 +68,14 @@ def hints(variable):
else: else:
return {} return {}
@register_canonicalize @register_canonicalize
@local_optimizer([]) @local_optimizer([])
def remove_hint_nodes(node): def remove_hint_nodes(node):
if is_hint_node(node): if is_hint_node(node):
# transfer hints from graph to Feature # transfer hints from graph to Feature
try: try:
for k,v in node.op.hints: for k, v in node.op.hints:
node.env.hints_feature.add_hint(node.inputs[0], k, v) node.env.hints_feature.add_hint(node.inputs[0], k, v)
except AttributeError: except AttributeError:
pass pass
...@@ -75,29 +86,34 @@ class HintsFeature(object): ...@@ -75,29 +86,34 @@ class HintsFeature(object):
""" """
Env Feature to track matrix properties Env Feature to track matrix properties
This is a similar feature to variable 'tags'. In fact, tags are one way to provide hints. This is a similar feature to variable 'tags'. In fact, tags are one way
to provide hints.
This class exists because tags were not documented well, and the semantics of how tag This class exists because tags were not documented well, and the
information should be moved around during optimizations was never clearly spelled out. semantics of how tag information should be moved around during
optimizations was never clearly spelled out.
Hints are assumptions about mathematical properties of variables. Hints are assumptions about mathematical properties of variables.
If one variable is substituted for another by an optimization, If one variable is substituted for another by an optimization,
then it means that the assumptions should be transferred to the new variable. then it means that the assumptions should be transferred to the
new variable.
Hints are attached to 'positions in a graph' rather than to variables in particular, Hints are attached to 'positions in a graph' rather than to variables
although Hints are originally attached to a particular position in a graph *via* a in particular, although Hints are originally attached to a particular
variable in that original graph. position in a graph *via* a variable in that original graph.
Examples of hints are: Examples of hints are:
- shape information - shape information
- matrix properties (e.g. symmetry, psd, banded, diagonal) - matrix properties (e.g. symmetry, psd, banded, diagonal)
Hint information is propagated through the graph similarly to graph optimizations, Hint information is propagated through the graph similarly to graph
except that adding a hint does not change the graph. Adding a hint is not something that optimizations, except that adding a hint does not change the graph.
debugmode will check. Adding a hint is not something that debugmode will check.
#TODO: should a Hint be an object that can actually evaluate its truthfulness? #TODO: should a Hint be an object that can actually evaluate its
# Should the PSD property be an object that can check the PSD-ness of a variable? # truthfulness?
# Should the PSD property be an object that can check the
# PSD-ness of a variable?
""" """
def add_hint(self, r, k, v): def add_hint(self, r, k, v):
...@@ -107,6 +123,7 @@ class HintsFeature(object): ...@@ -107,6 +123,7 @@ class HintsFeature(object):
def ensure_init_r(self, r): def ensure_init_r(self, r):
if r not in self.hints: if r not in self.hints:
self.hints[r] = {} self.hints[r] = {}
# #
# #
# Feature interface # Feature interface
...@@ -115,7 +132,8 @@ class HintsFeature(object): ...@@ -115,7 +132,8 @@ class HintsFeature(object):
def on_attach(self, env): def on_attach(self, env):
assert not hasattr(env, 'hints_feature') assert not hasattr(env, 'hints_feature')
env.hints_feature = self env.hints_feature = self
self.hints = {} # Variable -> tuple(scalars) or None (All tensor vars map to tuple) # Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self.hints = {}
for node in env.toposort(): for node in env.toposort():
self.on_import(env, node) self.on_import(env, node)
...@@ -133,7 +151,7 @@ class HintsFeature(object): ...@@ -133,7 +151,7 @@ class HintsFeature(object):
def update_second_from_first(self, r0, r1): def update_second_from_first(self, r0, r1):
old_hints = self.hints[r0] old_hints = self.hints[r0]
new_hints = self.hints[r1] new_hints = self.hints[r1]
for k,v in old_hints.items(): for k, v in old_hints.items():
if k in new_hints and new_hints[k] is not v: if k in new_hints and new_hints[k] is not v:
raise NotImplementedError() raise NotImplementedError()
if k not in new_hints: if k not in new_hints:
...@@ -151,6 +169,7 @@ class HintsFeature(object): ...@@ -151,6 +169,7 @@ class HintsFeature(object):
# 1) we are trying to get rid of r, or # 1) we are trying to get rid of r, or
# 2) we are putting things back after a failed transaction. # 2) we are putting things back after a failed transaction.
class HintsOptimizer(Optimizer): class HintsOptimizer(Optimizer):
"""Optimizer that serves to add HintsFeature as an env feature. """Optimizer that serves to add HintsFeature as an env feature.
""" """
...@@ -163,7 +182,11 @@ class HintsOptimizer(Optimizer): ...@@ -163,7 +182,11 @@ class HintsOptimizer(Optimizer):
def apply(self, env): def apply(self, env):
pass pass
# -1 should make it run right before the first merge # -1 should make it run right before the first merge
theano.compile.mode.optdb.register('HintsOpt', HintsOptimizer(), -1, 'fast_run', 'fast_compile') theano.compile.mode.optdb.register('HintsOpt',
HintsOptimizer(),
-1,
'fast_run',
'fast_compile')
def psd(v): def psd(v):
...@@ -176,8 +199,12 @@ def psd(v): ...@@ -176,8 +199,12 @@ def psd(v):
def is_psd(v): def is_psd(v):
return hints(v).get('psd', False) return hints(v).get('psd', False)
def is_symmetric(v): def is_symmetric(v):
return hints(v).get('symmetric', False) return hints(v).get('symmetric', False)
def is_positive(v): def is_positive(v):
if hints(v).get('positive', False): if hints(v).get('positive', False):
return True return True
...@@ -200,7 +227,7 @@ def inv_as_solve(node): ...@@ -200,7 +227,7 @@ def inv_as_solve(node):
if not imported_scipy: if not imported_scipy:
return False return False
if node.op == dot: if node.op == dot:
l,r = node.inputs l, r = node.inputs
if l.owner and l.owner.op == matrix_inverse: if l.owner and l.owner.op == matrix_inverse:
return [solve(l.owner.inputs[0], r)] return [solve(l.owner.inputs[0], r)]
if r.owner and r.owner.op == matrix_inverse: if r.owner and r.owner.op == matrix_inverse:
...@@ -209,6 +236,7 @@ def inv_as_solve(node): ...@@ -209,6 +236,7 @@ def inv_as_solve(node):
else: else:
return [solve(r.owner.inputs[0].T, l.T).T] return [solve(r.owner.inputs[0].T, l.T).T]
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
...@@ -216,16 +244,17 @@ def inv_as_solve(node): ...@@ -216,16 +244,17 @@ def inv_as_solve(node):
def no_transpose_symmetric(node): def no_transpose_symmetric(node):
if isinstance(node.op, DimShuffle): if isinstance(node.op, DimShuffle):
x = node.inputs[0] x = node.inputs[0]
if x.type.ndim==2 and is_symmetric(x): if x.type.ndim == 2 and is_symmetric(x):
#print 'UNDOING TRANSPOSE', is_symmetric(x), x.ndim #print 'UNDOING TRANSPOSE', is_symmetric(x), x.ndim
if node.op.new_order == [1,0]: if node.op.new_order == [1, 0]:
return [x] return [x]
@register_stabilize @register_stabilize
@local_optimizer([]) @local_optimizer([])
def psd_solve_with_chol(node): def psd_solve_with_chol(node):
if node.op == solve: if node.op == solve:
A, b = node.inputs #result is solution Ax=b A, b = node.inputs # result is solution Ax=b
if is_psd(A): if is_psd(A):
L = cholesky(A) L = cholesky(A)
#N.B. this can be further reduced to a yet-unwritten cho_solve Op #N.B. this can be further reduced to a yet-unwritten cho_solve Op
...@@ -235,6 +264,7 @@ def psd_solve_with_chol(node): ...@@ -235,6 +264,7 @@ def psd_solve_with_chol(node):
x = Solve('upper_triangular')(L.T, Li_b) x = Solve('upper_triangular')(L.T, Li_b)
return [x] return [x]
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@local_optimizer([]) @local_optimizer([])
...@@ -246,10 +276,10 @@ def local_det_chol(node): ...@@ -246,10 +276,10 @@ def local_det_chol(node):
""" """
if node.op == det: if node.op == det:
x, = node.inputs x, = node.inputs
for (cl,xpos) in x.clients: for (cl, xpos) in x.clients:
if isinstance(cl.op, Cholesky): if isinstance(cl.op, Cholesky):
L = cl.outputs[0] L = cl.outputs[0]
return [tensor.prod(extract_diag(L)**2)] return [tensor.prod(extract_diag(L) ** 2)]
@register_canonicalize @register_canonicalize
...@@ -268,8 +298,9 @@ def local_log_prod_sqr(node): ...@@ -268,8 +298,9 @@ def local_log_prod_sqr(node):
if is_positive(p): if is_positive(p):
return [tensor.log(p).sum(axis=x.owner.op.axis)] return [tensor.log(p).sum(axis=x.owner.op.axis)]
#TODO: have a reduction like prod and sum that simply returns the sign #TODO: have a reduction like prod and sum that simply
# of the prod multiplication. # returns the sign of the prod multiplication.
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
...@@ -443,6 +474,7 @@ class CholeskyGrad(Op): ...@@ -443,6 +474,7 @@ class CholeskyGrad(Op):
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0]] return [shapes[0]]
class MatrixPinv(Op): class MatrixPinv(Op):
def __init__(self): def __init__(self):
pass pass
...@@ -459,7 +491,7 @@ class MatrixPinv(Op): ...@@ -459,7 +491,7 @@ class MatrixPinv(Op):
return hash((type(self), self.props())) return hash((type(self), self.props()))
def __eq__(self, other): def __eq__(self, other):
return (type(self)==type(other) and self.props() == other.props()) return (type(self) == type(other) and self.props() == other.props())
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
...@@ -474,11 +506,13 @@ class MatrixPinv(Op): ...@@ -474,11 +506,13 @@ class MatrixPinv(Op):
except numpy.linalg.LinAlgError: except numpy.linalg.LinAlgError:
logger.debug('Failed to invert %s' % str(node.inputs[0])) logger.debug('Failed to invert %s' % str(node.inputs[0]))
raise raise
def __str__(self): def __str__(self):
return "MatrixPseudoInverse" return "MatrixPseudoInverse"
pinv = MatrixPinv() pinv = MatrixPinv()
class MatrixInverse(Op): class MatrixInverse(Op):
"""Computes the inverse of a matrix :math:`A`. """Computes the inverse of a matrix :math:`A`.
...@@ -505,7 +539,7 @@ class MatrixInverse(Op): ...@@ -505,7 +539,7 @@ class MatrixInverse(Op):
return hash((type(self), self.props())) return hash((type(self), self.props()))
def __eq__(self, other): def __eq__(self, other):
return (type(self)==type(other) and self.props() == other.props()) return (type(self) == type(other) and self.props() == other.props())
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
...@@ -535,7 +569,7 @@ class MatrixInverse(Op): ...@@ -535,7 +569,7 @@ class MatrixInverse(Op):
xi = self(x) xi = self(x)
gz, = g_outputs gz, = g_outputs
#TT.dot(gz.T,xi) #TT.dot(gz.T,xi)
return [-matrix_dot(xi,gz.T,xi).T] return [-matrix_dot(xi, gz.T, xi).T]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
"""The gradient function should return: """The gradient function should return:
...@@ -555,34 +589,43 @@ class MatrixInverse(Op): ...@@ -555,34 +589,43 @@ class MatrixInverse(Op):
ev, = eval_points ev, = eval_points
if ev is None: if ev is None:
return [None] return [None]
return [-matrix_dot(xi,ev,xi)] return [-matrix_dot(xi, ev, xi)]
def __str__(self): def __str__(self):
return "MatrixInverse" return "MatrixInverse"
matrix_inverse = MatrixInverse() matrix_inverse = MatrixInverse()
class Solve(Op): class Solve(Op):
"""Solve a system of linear equations""" """Solve a system of linear equations"""
def __init__(self, A_structure='general', lower=False, overwrite_A=False, overwrite_b=False): def __init__(self,
A_structure='general',
lower=False,
overwrite_A=False,
overwrite_b=False):
if A_structure not in MATRIX_STRUCTURES: if A_structure not in MATRIX_STRUCTURES:
raise ValueError('Invalid matrix structure argument', A_structure) raise ValueError('Invalid matrix structure argument', A_structure)
self.A_structure = A_structure self.A_structure = A_structure
self.lower=lower self.lower = lower
self.overwrite_A=overwrite_A self.overwrite_A = overwrite_A
self.overwrite_b=overwrite_b self.overwrite_b = overwrite_b
def props(self): def props(self):
return (self.A_structure, return (self.A_structure,
self.lower, self.lower,
self.overwrite_A, self.overwrite_A,
self.overwrite_b) self.overwrite_b)
def __hash__(self): def __hash__(self):
return hash((type(self),self.props())) return hash((type(self), self.props()))
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.props() == other.props() return type(self) == type(other) and self.props() == other.props()
def __repr__(self): def __repr__(self):
return 'Solve{%s}'%str(self.props()) return 'Solve{%s}' % str(self.props())
def make_node(self, A, b): def make_node(self, A, b):
assert imported_scipy, ( assert imported_scipy, (
"Scipy not available. Scipy is needed for the Solve op") "Scipy not available. Scipy is needed for the Solve op")
...@@ -590,30 +633,34 @@ class Solve(Op): ...@@ -590,30 +633,34 @@ class Solve(Op):
b = as_tensor_variable(b) b = as_tensor_variable(b)
otype = tensor.tensor( otype = tensor.tensor(
broadcastable=b.broadcastable, broadcastable=b.broadcastable,
dtype = (A*b).dtype) dtype=(A * b).dtype)
return Apply(self, [A,b], [otype]) return Apply(self, [A, b], [otype])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
A, b = inputs A, b = inputs
#TODO: use the A_structure to go faster #TODO: use the A_structure to go faster
output_storage[0][0] = scipy.linalg.solve(A,b) output_storage[0][0] = scipy.linalg.solve(A, b)
solve = Solve() # general solve
solve = Solve() # general solve
#TODO : SolveTriangular #TODO : SolveTriangular
#TODO: Optimizations to replace multiplication by matrix inverse with solve() Op (still unwritten) #TODO: Optimizations to replace multiplication by matrix inverse
# with solve() Op (still unwritten)
class ExtractDiag(Op): class ExtractDiag(Op):
""" Return the diagonal of a matrix. """ """ Return the diagonal of a matrix. """
def __init__(self, view=False): def __init__(self, view=False):
self.view = view self.view = view
if self.view: if self.view:
self.view_map = {0:[0]} self.view_map = {0: [0]}
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.view == other.view return type(self) == type(other) and self.view == other.view
def __hash__(self): def __hash__(self):
return hash(type(self))^hash(self.view) return hash(type(self)) ^ hash(self.view)
def make_node(self, _x): def make_node(self, _x):
x = as_tensor_variable(_x) x = as_tensor_variable(_x)
...@@ -622,7 +669,8 @@ class ExtractDiag(Op): ...@@ -622,7 +669,8 @@ class ExtractDiag(Op):
return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)]) return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)])
def perform(self, node, ins, outs): def perform(self, node, ins, outs):
""" For some reason numpy.diag(x) is really slow, so we implemented our own. """ """ For some reason numpy.diag(x) is really slow, so we
implemented our own. """
x, = ins x, = ins
z, = outs z, = outs
...@@ -631,24 +679,26 @@ class ExtractDiag(Op): ...@@ -631,24 +679,26 @@ class ExtractDiag(Op):
z[0] = numpy.zeros(0) z[0] = numpy.zeros(0)
return return
if x.shape[0] < x.shape [1]: if x.shape[0] < x.shape[1]:
rval = x[:,0] rval = x[:, 0]
else: else:
rval = x[0] rval = x[0]
rval.strides = (x.strides[0]+x.strides[1],) rval.strides = (x.strides[0] + x.strides[1],)
if self.view: if self.view:
z[0] = rval z[0] = rval
else: else:
z[0] = rval.copy() z[0] = rval.copy()
def __str__(self): def __str__(self):
return 'ExtractDiag{view=%s}'%self.view return 'ExtractDiag{view=%s}' % self.view
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
x = tensor.zeros_like(inputs[0]) x = tensor.zeros_like(inputs[0])
xdiag = alloc_diag(g_outputs[0]) xdiag = alloc_diag(g_outputs[0])
return [tensor.set_subtensor(x[:xdiag.shape[0], :xdiag.shape[1]], xdiag)] return [tensor.set_subtensor(
x[:xdiag.shape[0], :xdiag.shape[1]],
xdiag)]
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
x_s, = shapes x_s, = shapes
...@@ -685,7 +735,7 @@ class AllocDiag(Op): ...@@ -685,7 +735,7 @@ class AllocDiag(Op):
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
x_s, = shapes x_s, = shapes
return [(x_s[0],x_s[0])] return [(x_s[0], x_s[0])]
alloc_diag = AllocDiag() alloc_diag = AllocDiag()
...@@ -701,7 +751,7 @@ def diag(x): ...@@ -701,7 +751,7 @@ def diag(x):
xx = as_tensor_variable(x) xx = as_tensor_variable(x)
if xx.type.ndim == 1: if xx.type.ndim == 1:
return alloc_diag(xx) return alloc_diag(xx)
elif xx.type.ndim ==2: elif xx.type.ndim == 2:
return extract_diag(xx) return extract_diag(xx)
else: else:
raise TypeError('diag requires vector or matrix argument', x) raise TypeError('diag requires vector or matrix argument', x)
...@@ -768,7 +818,8 @@ def spectral_radius_bound(X, log2_exponent): ...@@ -768,7 +818,8 @@ def spectral_radius_bound(X, log2_exponent):
XX = tensor.dot(XX, XX) XX = tensor.dot(XX, XX)
return tensor.pow( return tensor.pow(
trace(XX), trace(XX),
2**(-log2_exponent)) 2 ** (-log2_exponent))
class A_Xinv_b(Op): class A_Xinv_b(Op):
"""Product of form a inv(X) b""" """Product of form a inv(X) b"""
...@@ -779,9 +830,10 @@ class A_Xinv_b(Op): ...@@ -779,9 +830,10 @@ class A_Xinv_b(Op):
b = as_tensor_variable(b) b = as_tensor_variable(b)
X = as_tensor_variable(X) X = as_tensor_variable(X)
o = theano.tensor.matrix(dtype=x.dtype) o = theano.tensor.matrix(dtype=x.dtype)
return Apply(self, [a,X,b], [o]) return Apply(self, [a, X, b], [o])
def perform(self, ndoe, inputs, outstor): def perform(self, ndoe, inputs, outstor):
a,X,b = inputs a, X, b = inputs
if 1: if 1:
L_factor = scipy.linalg.cho_factor(X) L_factor = scipy.linalg.cho_factor(X)
xb = scipy.linalg.cho_solve(L_factor, b) xb = scipy.linalg.cho_solve(L_factor, b)
...@@ -789,10 +841,11 @@ class A_Xinv_b(Op): ...@@ -789,10 +841,11 @@ class A_Xinv_b(Op):
z = numpy.dot(xa.T, xb) z = numpy.dot(xa.T, xb)
else: else:
raise NotImplementedError(self.X_structure) raise NotImplementedError(self.X_structure)
outstor[0][0]=z outstor[0][0] = z
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
gz, = g_outputs gz, = g_outputs
a,X,b = inputs a, X, b = inputs
iX = matrix_inverse(X) iX = matrix_inverse(X)
ga = matrix_dot(gz, b.T, iX.T) ga = matrix_dot(gz, b.T, iX.T)
gX = -matrix_dot(iX.T, a, gz, b.T, iX.T) gX = -matrix_dot(iX.T, a, gz, b.T, iX.T)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论