提交 c89e1bc2 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3107 from harlouci/flake8_v3

Flake8 tensor
...@@ -174,3 +174,13 @@ def test_tag_solve_triangular(): ...@@ -174,3 +174,13 @@ def test_tag_solve_triangular():
for node in f.maker.fgraph.toposort(): for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve): if isinstance(node.op, Solve):
assert node.op.A_structure == 'upper_triangular' assert node.op.A_structure == 'upper_triangular'
def test_matrix_inverse_solve():
if not imported_scipy:
raise SkipTest("Scipy needed for the Solve op.")
A = theano.tensor.dmatrix('A')
b = theano.tensor.dmatrix('b')
node = matrix_inverse(A).dot(b).owner
[out] = inv_as_solve.transform(node)
assert isinstance(out.owner.op, Solve)
...@@ -276,7 +276,7 @@ SOMEPATH/Canopy_64bit/User/lib/python2.7/site-packages/numpy/distutils/system_in ...@@ -276,7 +276,7 @@ SOMEPATH/Canopy_64bit/User/lib/python2.7/site-packages/numpy/distutils/system_in
# Using "conda install mkl" will install both, as well as # Using "conda install mkl" will install both, as well as
# optimized versions of numpy and scipy. # optimized versions of numpy and scipy.
try: try:
import mkl #noqa import mkl # noqa
except ImportError as e: except ImportError as e:
_logger.info('Conda mkl is not available: %s', e) _logger.info('Conda mkl is not available: %s', e)
else: else:
...@@ -333,12 +333,10 @@ try: ...@@ -333,12 +333,10 @@ try:
# `scipy.linalg.blas.fblas` with `scipy.linalg.blas`. # `scipy.linalg.blas.fblas` with `scipy.linalg.blas`.
# See http://github.com/scipy/scipy/pull/358 # See http://github.com/scipy/scipy/pull/358
fblas = scipy.linalg.blas fblas = scipy.linalg.blas
_blas_gemv_fns = { _blas_gemv_fns = {numpy.dtype('float32'): fblas.sgemv,
numpy.dtype('float32'): fblas.sgemv,
numpy.dtype('float64'): fblas.dgemv, numpy.dtype('float64'): fblas.dgemv,
numpy.dtype('complex64'): fblas.cgemv, numpy.dtype('complex64'): fblas.cgemv,
numpy.dtype('complex128'): fblas.zgemv, numpy.dtype('complex128'): fblas.zgemv}
}
except ImportError as e: except ImportError as e:
have_fblas = False have_fblas = False
# This is used in Gemv and ScipyGer. We use CGemv and CGer # This is used in Gemv and ScipyGer. We use CGemv and CGer
...@@ -401,7 +399,7 @@ class Gemv(Op): ...@@ -401,7 +399,7 @@ class Gemv(Op):
# sizes are 1 at time of perform() there is no problem # sizes are 1 at time of perform() there is no problem
# if x.broadcastable[0] != A.broadcastable[1]: # if x.broadcastable[0] != A.broadcastable[1]:
# raise TypeError('broadcastable mismatch between x and A', # raise TypeError('broadcastable mismatch between x and A',
#(x.type, A.type)) # (x.type, A.type))
return Apply(self, [y, alpha, A, x, beta], [y.type()]) return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage): def perform(self, node, inputs, out_storage):
...@@ -411,7 +409,8 @@ class Gemv(Op): ...@@ -411,7 +409,8 @@ class Gemv(Op):
gemv = _blas_gemv_fns[y.dtype] gemv = _blas_gemv_fns[y.dtype]
if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]): if (A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]):
raise ValueError('Incompatible shapes for gemv ' raise ValueError(
'Incompatible shapes for gemv '
'(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s ' '(beta * y + alpha * dot(A, x)). y: %s, A: %s, x: %s '
% (y.shape, A.shape, x.shape)) % (y.shape, A.shape, x.shape))
...@@ -571,7 +570,8 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir): ...@@ -571,7 +570,8 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
if any([f.find(ll) >= 0 for ll in l]): if any([f.find(ll) >= 0 for ll in l]):
found_dyn = True found_dyn = True
if not found_dyn and dirs: if not found_dyn and dirs:
_logger.warning("We did not found a dynamic library into the " _logger.warning(
"We did not found a dynamic library into the "
"library_dir of the library we use for blas. If you use " "library_dir of the library we use for blas. If you use "
"ATLAS, make sure to compile it with dynamics library.") "ATLAS, make sure to compile it with dynamics library.")
...@@ -673,7 +673,7 @@ class GemmRelated(Op): ...@@ -673,7 +673,7 @@ class GemmRelated(Op):
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1; int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
""" """
#setup_z_Nz_Sz = None # setup_z_Nz_Sz = None
check_xyz_rank2 = """ check_xyz_rank2 = """
if (PyArray_NDIM(%(_x)s) != 2) { if (PyArray_NDIM(%(_x)s) != 2) {
...@@ -823,7 +823,7 @@ class GemmRelated(Op): ...@@ -823,7 +823,7 @@ class GemmRelated(Op):
{ {
""" """
#case_float_ab_constants = None # case_float_ab_constants = None
case_float_gemm = """ case_float_gemm = """
float* x = (float*)PyArray_DATA(%(_x)s); float* x = (float*)PyArray_DATA(%(_x)s);
...@@ -856,7 +856,7 @@ class GemmRelated(Op): ...@@ -856,7 +856,7 @@ class GemmRelated(Op):
{ {
""" """
#case_double_ab_constants = None # case_double_ab_constants = None
case_double_gemm = """ case_double_gemm = """
double* x = (double*)PyArray_DATA(%(_x)s); double* x = (double*)PyArray_DATA(%(_x)s);
...@@ -1030,8 +1030,8 @@ class Gemm(GemmRelated): ...@@ -1030,8 +1030,8 @@ class Gemm(GemmRelated):
raise TypeError(Gemm.E_mixed, raise TypeError(Gemm.E_mixed,
(z.dtype, a.dtype, x.dtype, y.dtype, b.dtype)) (z.dtype, a.dtype, x.dtype, y.dtype, b.dtype))
if (not z.dtype.startswith('float') if (not z.dtype.startswith('float') and
and not z.dtype.startswith('complex')): not z.dtype.startswith('complex')):
raise TypeError(Gemm.E_float, (z.dtype)) raise TypeError(Gemm.E_float, (z.dtype))
output = z.type() output = z.type()
...@@ -1173,7 +1173,7 @@ class Gemm(GemmRelated): ...@@ -1173,7 +1173,7 @@ class Gemm(GemmRelated):
_z, _a, _x, _y, _b = inp _z, _a, _x, _y, _b = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__) % self.__class__.__name__)
if not config.blas.ldflags: if not config.blas.ldflags:
return super(Gemm, self).c_code(node, name, return super(Gemm, self).c_code(node, name,
...@@ -1203,9 +1203,9 @@ def res_is_a(node, op, maxclients=None): ...@@ -1203,9 +1203,9 @@ def res_is_a(node, op, maxclients=None):
else: else:
retval = True retval = True
return node.owner \ return (node.owner and
and node.owner.op == op \ node.owner.op == op and
and retval retval)
def _as_scalar(res, dtype=None): def _as_scalar(res, dtype=None):
...@@ -1235,16 +1235,16 @@ def _as_scalar(res, dtype=None): ...@@ -1235,16 +1235,16 @@ def _as_scalar(res, dtype=None):
def _is_real_matrix(res): def _is_real_matrix(res):
return res.type.dtype in ('float32', 'float64') \ return (res.type.dtype in ('float32', 'float64') and
and res.type.ndim == 2 \ res.type.ndim == 2 and
and res.type.broadcastable[0] == False \ res.type.broadcastable[0] == False and
and res.type.broadcastable[1] == False # cope with tuple vs. list res.type.broadcastable[1] == False) # cope with tuple vs. list
def _is_real_vector(res): def _is_real_vector(res):
return res.type.dtype in ('float32', 'float64') \ return (res.type.dtype in ('float32', 'float64') and
and res.type.ndim == 1 \ res.type.ndim == 1 and
and res.type.broadcastable[0] == False res.type.broadcastable[0] == False)
def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True): def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip=True):
...@@ -1583,7 +1583,7 @@ class GemmOptimizer(Optimizer): ...@@ -1583,7 +1583,7 @@ class GemmOptimizer(Optimizer):
(theano.scalar.Add, theano.scalar.Sub, (theano.scalar.Add, theano.scalar.Sub,
theano.scalar.Neg, theano.scalar.Mul))): theano.scalar.Neg, theano.scalar.Mul))):
continue continue
if not node in fgraph.apply_nodes: if node not in fgraph.apply_nodes:
# This mean that we already removed this node from # This mean that we already removed this node from
# the graph # the graph
continue continue
...@@ -1592,7 +1592,7 @@ class GemmOptimizer(Optimizer): ...@@ -1592,7 +1592,7 @@ class GemmOptimizer(Optimizer):
time_canonicalize += time1 time_canonicalize += time1
time_factor_can += time2 time_factor_can += time2
time_factor_list += time3 time_factor_list += time3
except InconsistencyError as e: except InconsistencyError:
nb_inconsistency_make += 1 nb_inconsistency_make += 1
continue continue
if new_outputs: if new_outputs:
...@@ -1725,7 +1725,7 @@ class Dot22(GemmRelated): ...@@ -1725,7 +1725,7 @@ class Dot22(GemmRelated):
_x, _y = inp _x, _y = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__) % self.__class__.__name__)
if len(self.c_libraries()) <= 0: if len(self.c_libraries()) <= 0:
return super(Dot22, self).c_code(node, name, (_x, _y), return super(Dot22, self).c_code(node, name, (_x, _y),
...@@ -1898,8 +1898,7 @@ blas_optdb.register('gemm_optimizer', ...@@ -1898,8 +1898,7 @@ blas_optdb.register('gemm_optimizer',
GemmOptimizer(), GemmOptimizer(),
10, 'fast_run') 10, 'fast_run')
blas_optdb.register('local_gemm_to_gemv', blas_optdb.register('local_gemm_to_gemv',
EquilibriumOptimizer([ EquilibriumOptimizer([local_gemm_to_gemv,
local_gemm_to_gemv,
local_gemm_to_ger, local_gemm_to_ger,
local_dot22_to_ger_or_gemv, local_dot22_to_ger_or_gemv,
local_dimshuffle_lift], local_dimshuffle_lift],
...@@ -1938,8 +1937,8 @@ class Dot22Scalar(GemmRelated): ...@@ -1938,8 +1937,8 @@ class Dot22Scalar(GemmRelated):
raise TypeError('Dot22Scalar requires matching dtypes', raise TypeError('Dot22Scalar requires matching dtypes',
(a.dtype, x.dtype, y.dtype)) (a.dtype, x.dtype, y.dtype))
if (not a.dtype.startswith('float') if (not a.dtype.startswith('float') and
and not a.dtype.startswith('complex')): not a.dtype.startswith('complex')):
raise TypeError('Dot22Scalar requires float or complex args', raise TypeError('Dot22Scalar requires float or complex args',
a.dtype) a.dtype)
...@@ -1992,7 +1991,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1992,7 +1991,7 @@ class Dot22Scalar(GemmRelated):
_x, _y, _a = inp _x, _y, _a = inp
_zout, = out _zout, = out
if node.inputs[0].type.dtype.startswith('complex'): if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \ raise utils.MethodNotDefined('%s.c_code'
% self.__class__.__name__) % self.__class__.__name__)
if len(self.c_libraries()) <= 0: if len(self.c_libraries()) <= 0:
return super(Dot22Scalar, self).c_code(node, name, (_x, _y), return super(Dot22Scalar, self).c_code(node, name, (_x, _y),
...@@ -2065,8 +2064,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -2065,8 +2064,7 @@ def local_dot22_to_dot22scalar(node):
scalar_idx = -1 scalar_idx = -1
for i, x in enumerate(m.owner.inputs): for i, x in enumerate(m.owner.inputs):
if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast( if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(
x.type.dtype, d.type.dtype) x.type.dtype, d.type.dtype) == d.type.dtype):
== d.type.dtype):
scalar_idx = i scalar_idx = i
break break
...@@ -2128,8 +2126,8 @@ blas_optdb.register('local_dot22_to_dot22scalar', ...@@ -2128,8 +2126,8 @@ blas_optdb.register('local_dot22_to_dot22scalar',
11, 'fast_run') 11, 'fast_run')
#from opt import register_specialize, register_canonicalize # from opt import register_specialize, register_canonicalize
#@register_specialize # @register_specialize
@local_optimizer([T.sub, T.add]) @local_optimizer([T.sub, T.add])
def local_print_as_we_go_along(node): def local_print_as_we_go_along(node):
if node.op in (T.sub, T.add): if node.op in (T.sub, T.add):
......
...@@ -4,8 +4,7 @@ from theano import config ...@@ -4,8 +4,7 @@ from theano import config
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
from theano.tensor.blas import ldflags, blas_header_text, blas_header_version from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import ( from theano.tensor.blas import blas_optdb, optdb, local_optimizer
blas_optdb, optdb, local_optimizer, EquilibriumOptimizer)
from theano.tensor.blas import Ger, ger, ger_destructive from theano.tensor.blas import Ger, ger, ger_destructive
from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace
from theano.tensor import basic as T from theano.tensor import basic as T
...@@ -732,8 +731,7 @@ def check_force_gemv_init(): ...@@ -732,8 +731,7 @@ def check_force_gemv_init():
gemv_no_inplace(aa, 1., xx, yy, 0.), gemv_no_inplace(aa, 1., xx, yy, 0.),
theano.compile.Mode(optimizer='fast_compile').excluding('gpu', theano.compile.Mode(optimizer='fast_compile').excluding('gpu',
'gpuarray'), 'gpuarray'),
profile=False profile=False)
)
finally: finally:
theano.config.compute_test_value = tv theano.config.compute_test_value = tv
theano.config.compute_test_value_opt = tvo theano.config.compute_test_value_opt = tvo
...@@ -742,11 +740,11 @@ def check_force_gemv_init(): ...@@ -742,11 +740,11 @@ def check_force_gemv_init():
# then we want gemv_c_code to initiliaze the memory to 0 so that we # then we want gemv_c_code to initiliaze the memory to 0 so that we
# don't inadvertantly introduce NaNs to the users data. # don't inadvertantly introduce NaNs to the users data.
aa_data = numpy.array( aa_data = numpy.array(
float('NaN')*numpy.ones((2,)), float('NaN') * numpy.ones((2,)),
dtype=theano.config.floatX dtype=theano.config.floatX
) )
yy_data = numpy.array( yy_data = numpy.array(
numpy.ones((2,))*2, numpy.ones((2,)) * 2,
dtype=theano.config.floatX dtype=theano.config.floatX
) )
xx_data = numpy.array( xx_data = numpy.array(
......
...@@ -311,15 +311,13 @@ def make_reordered_loop(init_loop_orders, olv_index, dtypes, inner_task, sub, op ...@@ -311,15 +311,13 @@ def make_reordered_loop(init_loop_orders, olv_index, dtypes, inner_task, sub, op
total = "%(var)s_n%(candidate)s" % locals() total = "%(var)s_n%(candidate)s" % locals()
break break
else: else:
total = '1'; total = '1'
totals.append(total) totals.append(total)
declare_totals = """ declare_totals = """
int init_totals[%(nnested)s] = {%(totals)s}; int init_totals[%(nnested)s] = {%(totals)s};
""" % dict( """ % dict(nnested=nnested,
nnested=nnested, totals=', '.join(totals))
totals=', '.join(totals)
)
# Sort totals to match the new order that was computed by sorting # Sort totals to match the new order that was computed by sorting
# the loop vector. One integer variable per loop is declared. # the loop vector. One integer variable per loop is declared.
...@@ -355,11 +353,9 @@ def make_reordered_loop(init_loop_orders, olv_index, dtypes, inner_task, sub, op ...@@ -355,11 +353,9 @@ def make_reordered_loop(init_loop_orders, olv_index, dtypes, inner_task, sub, op
declare_strides = """ declare_strides = """
int init_strides[%(nvars)i][%(nnested)i] = { int init_strides[%(nvars)i][%(nnested)i] = {
%(strides)s %(strides)s
};""" % dict( };""" % dict(nvars=nvars,
nvars=nvars,
nnested=nnested, nnested=nnested,
strides=', \n'.join( strides=', \n'.join(', '.join(get_loop_strides(lo, i))
', '.join(get_loop_strides(lo, i))
for i, lo in enumerate(init_loop_orders) for i, lo in enumerate(init_loop_orders)
if len(lo) > 0)) if len(lo) > 0))
...@@ -385,9 +381,9 @@ def make_reordered_loop(init_loop_orders, olv_index, dtypes, inner_task, sub, op ...@@ -385,9 +381,9 @@ def make_reordered_loop(init_loop_orders, olv_index, dtypes, inner_task, sub, op
declare_iter += "%(var)s_iter = (%(dtype)s*)(PyArray_DATA(%(var)s));\n" % locals() declare_iter += "%(var)s_iter = (%(dtype)s*)(PyArray_DATA(%(var)s));\n" % locals()
pointer_update = '' pointer_update = ''
for j , dtype in enumerate(dtypes): for j, dtype in enumerate(dtypes):
var = sub["lv%i" % j] var = sub["lv%i" % j]
pointer_update += "%(dtype)s &%(var)s_i = * ( %(var)s_iter"%locals() pointer_update += "%(dtype)s &%(var)s_i = * ( %(var)s_iter" % locals()
tot_jump = '' tot_jump = ''
for i in reversed(range(nnested)): for i in reversed(range(nnested)):
iterv = 'ITER_%i' % i iterv = 'ITER_%i' % i
...@@ -401,7 +397,7 @@ def make_reordered_loop(init_loop_orders, olv_index, dtypes, inner_task, sub, op ...@@ -401,7 +397,7 @@ def make_reordered_loop(init_loop_orders, olv_index, dtypes, inner_task, sub, op
update = '' update = ''
forloop = '' forloop = ''
# The pointers are defined only in the most inner loop # The pointers are defined only in the most inner loop
if i == nnested-1: if i == nnested - 1:
update = pointer_update update = pointer_update
if i == 0: if i == 0:
if openmp: if openmp:
...@@ -417,15 +413,13 @@ def make_reordered_loop(init_loop_orders, olv_index, dtypes, inner_task, sub, op ...@@ -417,15 +413,13 @@ def make_reordered_loop(init_loop_orders, olv_index, dtypes, inner_task, sub, op
} // end loop %(i)i } // end loop %(i)i
""" % locals() """ % locals()
return '\n'.join([ return '\n'.join(['{',
'{',
order_loops, order_loops,
declare_totals, declare_totals,
declare_strides, declare_strides,
declare_iter, declare_iter,
loop, loop,
'}\n', '}\n'])
])
# print make_declare(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)), # print make_declare(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
# ('double', 'int', 'float'), # ('double', 'int', 'float'),
...@@ -451,16 +445,16 @@ def make_reordered_loop(init_loop_orders, olv_index, dtypes, inner_task, sub, op ...@@ -451,16 +445,16 @@ def make_reordered_loop(init_loop_orders, olv_index, dtypes, inner_task, sub, op
################## ##################
### DimShuffle ### # DimShuffle #
################## ##################
################# #################
### Broadcast ### # Broadcast #
################# #################
################ ################
### CAReduce ### # CAReduce #
################ ################
...@@ -527,4 +521,3 @@ def make_loop_careduce(loop_orders, dtypes, loop_tasks, sub): ...@@ -527,4 +521,3 @@ def make_loop_careduce(loop_orders, dtypes, loop_tasks, sub):
s += loop_tasks[-1] s += loop_tasks[-1]
return "{%s}" % s return "{%s}" % s
...@@ -5,6 +5,7 @@ from six.moves import xrange ...@@ -5,6 +5,7 @@ from six.moves import xrange
import theano import theano
from theano.tensor import basic from theano.tensor import basic
from theano.tensor import nlinalg # noqa
from theano import gof, scalar from theano import gof, scalar
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
tensor = basic tensor = basic
......
...@@ -87,8 +87,8 @@ class Fourier(gof.Op): ...@@ -87,8 +87,8 @@ class Fourier(gof.Op):
if len(shape_a) == 1: if len(shape_a) == 1:
return [(n,)] return [(n,)]
elif isinstance(axis, tensor.TensorConstant): elif isinstance(axis, tensor.TensorConstant):
out_shape = list(shape_a[0: axis.data.item()]) + [n] +\ out_shape = (list(shape_a[0: axis.data.item()]) + [n] +
list(shape_a[axis.data + 1:]) list(shape_a[axis.data + 1:]))
else: else:
l = len(shape_a) l = len(shape_a)
shape_a = tensor.stack(*shape_a) shape_a = tensor.stack(*shape_a)
...@@ -136,7 +136,8 @@ class Fourier(gof.Op): ...@@ -136,7 +136,8 @@ class Fourier(gof.Op):
flip_shape = list(numpy.arange(0, a.ndim)[::-1]) flip_shape = list(numpy.arange(0, a.ndim)[::-1])
res = res.dimshuffle(flip_shape) res = res.dimshuffle(flip_shape)
res = tensor.switch(tensor.lt(n, tensor.shape(a)[axis]), res = tensor.switch(tensor.lt(n, tensor.shape(a)[axis]),
tensor.set_subtensor(res[n::, ], 0, False, False), res) tensor.set_subtensor(res[n::, ], 0, False, False),
res)
res = res.dimshuffle(flip_shape) res = res.dimshuffle(flip_shape)
# insures that gradient shape conforms to input shape: # insures that gradient shape conforms to input shape:
......
from __future__ import print_function from __future__ import print_function
import logging import logging
import theano
logger = logging.getLogger(__name__)
import numpy import numpy
from six.moves import xrange from six.moves import xrange
import theano
from theano.tensor import as_tensor_variable
from theano.gof import Op, Apply from theano.gof import Op, Apply
from theano.tensor import as_tensor_variable, dot, DimShuffle, Dot
from theano.tensor.blas import Dot22
from theano.tensor.opt import (register_stabilize,
register_specialize, register_canonicalize)
from theano.gof import local_optimizer
from theano.gof.opt import Optimizer
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.tensor import basic as tensor from theano.tensor import basic as tensor
logger = logging.getLogger(__name__)
class MatrixPinv(Op): class MatrixPinv(Op):
"""Computes the pseudo-inverse of a matrix :math:`A`. """Computes the pseudo-inverse of a matrix :math:`A`.
...@@ -427,8 +423,10 @@ class EighGrad(Op): ...@@ -427,8 +423,10 @@ class EighGrad(Op):
N = x.shape[0] N = x.shape[0]
outer = numpy.outer outer = numpy.outer
G = lambda n: sum(v[:, m] * V.T[n].dot(v[:, m]) / (w[n] - w[m]) def G(n):
return sum(v[:, m] * V.T[n].dot(v[:, m]) / (w[n] - w[m])
for m in xrange(N) if m != n) for m in xrange(N) if m != n)
g = sum(outer(v[:, n], v[:, n] * W[n] + G(n)) g = sum(outer(v[:, n], v[:, n] * W[n] + G(n))
for n in xrange(N)) for n in xrange(N))
...@@ -641,16 +639,6 @@ def svd(a, full_matrices=1, compute_uv=1): ...@@ -641,16 +639,6 @@ def svd(a, full_matrices=1, compute_uv=1):
return SVD(full_matrices, compute_uv)(a) return SVD(full_matrices, compute_uv)(a)
def test_matrix_inverse_solve():
if not imported_scipy:
raise SkipTest("Scipy needed for the Solve op.")
A = theano.tensor.dmatrix('A')
b = theano.tensor.dmatrix('b')
node = matrix_inverse(A).dot(b).owner
[out] = inv_as_solve.transform(node)
assert isinstance(out.owner.op, Solve)
class lstsq(Op): class lstsq(Op):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -670,9 +658,6 @@ class lstsq(Op): ...@@ -670,9 +658,6 @@ class lstsq(Op):
theano.tensor.lscalar(), theano.tensor.dvector()]) theano.tensor.lscalar(), theano.tensor.dvector()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
x = inputs[0]
y = inputs[1]
rcond = inputs[2]
zz = numpy.linalg.lstsq(inputs[0], inputs[1], inputs[2]) zz = numpy.linalg.lstsq(inputs[0], inputs[1], inputs[2])
outputs[0][0] = zz[0] outputs[0][0] = zz[0]
outputs[1][0] = zz[1] outputs[1][0] = zz[1]
...@@ -703,7 +688,7 @@ def norm(x, ord): ...@@ -703,7 +688,7 @@ def norm(x, ord):
return x[x.nonzero()].shape[0] return x[x.nonzero()].shape[0]
else: else:
try: try:
z = tensor.sum(abs(x**ord))**(1./ord) z = tensor.sum(abs(x**ord))**(1. / ord)
except TypeError: except TypeError:
raise ValueError("Invalid norm order for vectors.") raise ValueError("Invalid norm order for vectors.")
return z return z
......
...@@ -33,7 +33,6 @@ supposed to be canonical. ...@@ -33,7 +33,6 @@ supposed to be canonical.
# TODO: intelligent merge for mul/add # TODO: intelligent merge for mul/add
# TODO: 0*x -> 0 # TODO: 0*x -> 0
import logging import logging
_logger = logging.getLogger('theano.tensor.opt')
from theano import gof from theano import gof
from theano.tensor.elemwise import CAReduce from theano.tensor.elemwise import CAReduce
...@@ -44,6 +43,8 @@ from theano.tensor.basic import (get_scalar_constant_value, ...@@ -44,6 +43,8 @@ from theano.tensor.basic import (get_scalar_constant_value,
from theano.tensor.opt import register_uncanonicalize from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal from theano import scalar as scal
_logger = logging.getLogger('theano.tensor.opt')
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T._max_and_argmax]) @gof.local_optimizer([T._max_and_argmax])
...@@ -81,8 +82,8 @@ def local_max_to_min(node): ...@@ -81,8 +82,8 @@ def local_max_to_min(node):
if node.op == T.neg and node.inputs[0].owner: if node.op == T.neg and node.inputs[0].owner:
max = node.inputs[0] max = node.inputs[0]
if (max.owner and if (max.owner and
isinstance(max.owner.op, CAReduce) isinstance(max.owner.op, CAReduce) and
and max.owner.op.scalar_op == scal.maximum): max.owner.op.scalar_op == scal.maximum):
neg = max.owner.inputs[0] neg = max.owner.inputs[0]
if neg.owner and neg.owner.op == T.neg: if neg.owner and neg.owner.op == T.neg:
return [CAReduce(scal.minimum, return [CAReduce(scal.minimum,
......
import logging import logging
_logger = logging.getLogger("theano.tensor.type") import warnings
import numpy import numpy
import theano import theano
from theano import config from theano import config
from theano.gof import Constant, hashtype, Type, Variable from theano.gof import hashtype, Type, Variable
from theano.gof.utils import MethodNotDefined
from theano import scalar as scal from theano import scalar as scal
_logger = logging.getLogger("theano.tensor.type")
class TensorType(Type): class TensorType(Type):
"""Symbolic `Type` representing a numpy.ndarray value.""" """Symbolic `Type` representing a numpy.ndarray value."""
...@@ -39,7 +40,7 @@ class TensorType(Type): ...@@ -39,7 +40,7 @@ class TensorType(Type):
if self.dtype == 'floatX': if self.dtype == 'floatX':
self.dtype = config.floatX self.dtype = config.floatX
# broadcastable is immutable, and all elements are either # broadcastable is immutable, and all elements are either
### True or False # True or False
self.broadcastable = tuple(bool(b) for b in broadcastable) self.broadcastable = tuple(bool(b) for b in broadcastable)
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
self.name = name self.name = name
...@@ -78,13 +79,13 @@ class TensorType(Type): ...@@ -78,13 +79,13 @@ class TensorType(Type):
'maybe you are trying to call a function on a (possibly ' 'maybe you are trying to call a function on a (possibly '
'shared) variable instead of a numeric array?') 'shared) variable instead of a numeric array?')
if ((type(data) is numpy.ndarray) if ((type(data) is numpy.ndarray) and
and (data.dtype == self.numpy_dtype)): (data.dtype == self.numpy_dtype)):
if data.dtype.num != self.numpy_dtype.num: if data.dtype.num != self.numpy_dtype.num:
data = theano._asarray(data, dtype=self.dtype) data = theano._asarray(data, dtype=self.dtype)
# -- now fall through to ndim check # -- now fall through to ndim check
elif((type(data) is numpy.memmap) elif ((type(data) is numpy.memmap) and
and (data.dtype == self.numpy_dtype)): (data.dtype == self.numpy_dtype)):
# numpy.memmap is a "safe" subclass of ndarray, # numpy.memmap is a "safe" subclass of ndarray,
# so we can use it whereever we expect a base ndarray. # so we can use it whereever we expect a base ndarray.
# however, casting it would defeat the purpose of not # however, casting it would defeat the purpose of not
...@@ -98,8 +99,8 @@ class TensorType(Type): ...@@ -98,8 +99,8 @@ class TensorType(Type):
data, type(data)) data, type(data))
if data.dtype != self.numpy_dtype: if data.dtype != self.numpy_dtype:
raise TypeError(("%s expected a ndarray object with " raise TypeError(("%s expected a ndarray object with "
"dtype = %s (got %s).") % ( "dtype = %s (got %s).") %
self, self.numpy_dtype, data.dtype)) (self, self.numpy_dtype, data.dtype))
assert False, "This point should never be reached." assert False, "This point should never be reached."
else: else:
if allow_downcast: if allow_downcast:
...@@ -210,12 +211,10 @@ class TensorType(Type): ...@@ -210,12 +211,10 @@ class TensorType(Type):
raise TypeError( raise TypeError(
'Cannot convert Type %(othertype)s ' 'Cannot convert Type %(othertype)s '
'(of Variable %(other)s) into Type %(self)s. ' '(of Variable %(other)s) into Type %(self)s. '
'You can try to manually convert %(other)s into a %(self)s.' 'You can try to manually convert %(other)s into a %(self)s.' %
% dict( dict(othertype=other.type,
othertype=other.type,
other=other, other=other,
self=self) self=self))
)
def value_validity_msg(self, a): def value_validity_msg(self, a):
try: try:
...@@ -261,7 +260,7 @@ class TensorType(Type): ...@@ -261,7 +260,7 @@ class TensorType(Type):
and other.broadcastable == self.broadcastable and other.broadcastable == self.broadcastable
def convert_variable(self, var): def convert_variable(self, var):
if (type(self) == type(var.type) and if (type(self) == type(var.type) and # noqa
self.dtype == var.type.dtype and self.dtype == var.type.dtype and
self.ndim == var.type.ndim and self.ndim == var.type.ndim and
all(sb == ob or ob for sb, ob in zip(self.broadcastable, all(sb == ob or ob for sb, ob in zip(self.broadcastable,
...@@ -422,7 +421,7 @@ class TensorType(Type): ...@@ -422,7 +421,7 @@ class TensorType(Type):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
#"TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable)) # "TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable))
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
"""Override `CLinkerType.c_declare` """ """Override `CLinkerType.c_declare` """
......
...@@ -57,17 +57,7 @@ whitelist_flake8 = [ ...@@ -57,17 +57,7 @@ whitelist_flake8 = [
"typed_list/tests/test_type.py", "typed_list/tests/test_type.py",
"typed_list/tests/test_opt.py", "typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py", "typed_list/tests/test_basic.py",
"tensor/blas_headers.py",
"tensor/type.py",
"tensor/fourier.py",
"tensor/__init__.py", "tensor/__init__.py",
"tensor/opt_uncanonicalize.py",
"tensor/blas.py",
"tensor/extra_ops.py",
"tensor/nlinalg.py",
"tensor/blas_c.py",
"tensor/elemwise_cgen.py",
"tensor/blas_scipy.py",
"tensor/tests/test_subtensor.py", "tensor/tests/test_subtensor.py",
"tensor/tests/test_utils.py", "tensor/tests/test_utils.py",
"tensor/tests/test_nlinalg.py", "tensor/tests/test_nlinalg.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论