提交 25bdeb76 authored 作者: nouiz's avatar nouiz

Merge pull request #145 from jaberg/linalg_misc

misc cleanup and a bugfix in linalg ops
import logging
logger = logging.getLogger(__name__)
import numpy
from theano.gof import Op, Apply
......@@ -57,6 +60,7 @@ def hints(variable):
@local_optimizer([])
def remove_hint_nodes(node):
if is_hint_node(node):
# transfer hints from graph to Feature
try:
for k,v in node.op.hints:
node.env.hints_feature.add_hint(node.inputs[0], k, v)
......@@ -95,7 +99,7 @@ class HintsFeature(object):
"""
def add_hint(self, r, k, v):
print 'adding hint', r, k, v
logger.debug('adding hint; %s, %s, %s' % (r, k, v))
self.hints[r][k] = v
def ensure_init_r(self, r):
......@@ -171,9 +175,8 @@ def is_positive(v):
return True
#TODO: how to handle this - a registry?
# infer_hints on Ops?
print 'is_positive', v
logger.debug('is_positive: %s' % str(v))
if v.owner and v.owner.op == tensor.pow:
print 'try for pow', v, v.owner.inputs
try:
exponent = tensor.get_constant_value(v.owner.inputs[1])
except TypeError:
......@@ -250,7 +253,6 @@ def local_log_prod_sqr(node):
# we cannot always make this substitution because
# the prod might include negative terms
p = x.owner.inputs[0]
print "AAA", p
# p is the matrix we're reducing with prod
if is_positive(p):
......@@ -316,7 +318,7 @@ class Cholesky(Op):
destr = 'destructive'
else:
destr = 'non-destructive'
return 'Cholesky{%s,%s}'% (lu,destr)
return 'Cholesky{%s,%s}' % (lu, destr)
def make_node(self, x):
x = as_tensor_variable(x)
return Apply(self, [x], [x.type()])
......@@ -378,7 +380,10 @@ class Solve(Op):
def make_node(self, A, b):
A = as_tensor_variable(A)
b = as_tensor_variable(b)
return Apply(self, [A,b], [b.type()])
otype = tensor.tensor(
broadcastable=b.broadcastable,
dtype = (A*b).dtype)
return Apply(self, [A,b], [otype])
def perform(self, node, inputs, output_storage):
A, b = inputs
#TODO: use the A_structure to go faster
......@@ -461,6 +466,7 @@ def diag(x):
class Det(Op):
"""matrix determinant
TODO: move this op to another file that request scipy.
"""
def make_node(self, x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论