提交 0d90f455 authored 作者: James Bergstra's avatar James Bergstra

merge

from .sharedvalue import shared from .sharedvalue import shared, shared_constructor
from .pfunc import pfunc from .pfunc import pfunc
...@@ -187,28 +187,28 @@ class Scalar(Type): ...@@ -187,28 +187,28 @@ class Scalar(Type):
}; };
""" """
operator_eq = """ operator_eq = """
template <> %(mytype)s & %(mytype)s::operator =(const npy_int8 & y) template <> %(mytype)s & %(mytype)s::operator=<npy_int8>(const npy_int8 & y)
{ this->real=y; this->imag=0; return *this; } { this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const npy_int16 & y) template <> %(mytype)s & %(mytype)s::operator=<npy_int16>(const npy_int16 & y)
{ this->real=y; this->imag=0; return *this; } { this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const npy_int32 & y) template <> %(mytype)s & %(mytype)s::operator=<npy_int32>(const npy_int32 & y)
{ this->real=y; this->imag=0; return *this; } { this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const npy_int64 & y) template <> %(mytype)s & %(mytype)s::operator=<npy_int64>(const npy_int64 & y)
{ this->real=y; this->imag=0; return *this; } { this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const npy_float32 & y) template <> %(mytype)s & %(mytype)s::operator=<npy_float32>(const npy_float32 & y)
{ this->real=y; this->imag=0; return *this; } { this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const npy_float64 & y) template <> %(mytype)s & %(mytype)s::operator=<npy_float64>(const npy_float64 & y)
{ this->real=y; this->imag=0; return *this; } { this->real=y; this->imag=0; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const theano_complex128 & y) template <> %(mytype)s & %(mytype)s::operator=<theano_complex128>(const theano_complex128 & y)
{ this->real=y.real; this->imag=y.imag; return *this; } { this->real=y.real; this->imag=y.imag; return *this; }
template <> %(mytype)s & %(mytype)s::operator =(const theano_complex64 & y) template <> %(mytype)s & %(mytype)s::operator=<theano_complex64>(const theano_complex64 & y)
{ this->real=y.real; this->imag=y.imag; return *this; } { this->real=y.real; this->imag=y.imag; return *this; }
""" """
...@@ -219,7 +219,8 @@ class Scalar(Type): ...@@ -219,7 +219,8 @@ class Scalar(Type):
+ operator_eq % dict(mytype='theano_complex64') + operator_eq % dict(mytype='theano_complex64')
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,) #explicit T given in specialization of operator= lines. This makes it compile with open64
#2,
int8 = Scalar('int8') int8 = Scalar('int8')
...@@ -666,10 +667,10 @@ class Mul(ScalarOp): ...@@ -666,10 +667,10 @@ class Mul(ScalarOp):
retval = [] retval = []
for input in inputs: for input in inputs:
if input.type in grad_types: if input.type in grad_types:
retval += [mul(*([gz] + utils.difference(inputs, [input])))] retval += [cast(mul(*([gz] + utils.difference(inputs, [input]))), input.type.dtype)]
else: else:
retval += [None] retval += [None]
return retval return retval
#return [(mul(*([gz] + utils.difference(inputs, [input]))) #return [(mul(*([gz] + utils.difference(inputs, [input])))
......
...@@ -1417,15 +1417,19 @@ def mean(input, axis = None): ...@@ -1417,15 +1417,19 @@ def mean(input, axis = None):
if str(input.dtype).startswith('int'): if str(input.dtype).startswith('int'):
# we need to cast eventually anyway, and this helps # we need to cast eventually anyway, and this helps
# to prevents overflow # to prevents overflow
input = convert_to_float64(input) input = cast(input, 'float64')
s = sum(input, axis) s = sum(input, axis)
shp = shape(input) shp = shape(input)
if input.dtype == 'float32':
shp = cast(shp, 'float32')
if axis is None: if axis is None:
axis = range(input.type.ndim) axis = range(input.type.ndim)
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] axis = [axis]
for i in axis: for i in axis:
s = s / shp[i] s = s / shp[i]
if input.dtype.startswith('float'):
assert input.dtype == s.dtype
return s return s
@constructor @constructor
...@@ -2543,12 +2547,15 @@ class Dot(Op): ...@@ -2543,12 +2547,15 @@ class Dot(Op):
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
if gz.type.ndim == 0: if gz.type.ndim == 0:
return gz * y, gz * x rval = gz * y, gz * x
if x.type.ndim == 1 and y.type.ndim > 1: elif x.type.ndim == 1 and y.type.ndim > 1:
return dot(gz, y.T), outer(x.T, gz) rval = dot(gz, y.T), outer(x.T, gz)
if x.type.ndim > 1 and y.type.ndim == 1: elif x.type.ndim > 1 and y.type.ndim == 1:
return outer(gz, y.T), dot(x.T, gz) rval = outer(gz, y.T), dot(x.T, gz)
return dot(gz, y.T), dot(x.T, gz) else:
rval = dot(gz, y.T), dot(x.T, gz)
return cast(rval[0], x.dtype), cast(rval[1], y.dtype)
def __str__(self): def __str__(self):
return "dot" return "dot"
dot = Dot() dot = Dot()
......
"""Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions""" """Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions"""
import os, sys, traceback import os, sys, traceback, logging
import numpy import numpy
from theano.gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler, from theano.gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler,
...@@ -17,6 +17,13 @@ from theano import compile #to register the optimizer built by this file ...@@ -17,6 +17,13 @@ from theano import compile #to register the optimizer built by this file
from theano.tensor.blas_headers import cblas_header_text, blas_header_text from theano.tensor.blas_headers import cblas_header_text, blas_header_text
# Module-level logger for this file, plus thin convenience wrappers.
# Each wrapper converts every positional argument with str(), joins the
# pieces with single spaces, and emits the result as one log message.
_logger = logging.getLogger('theano.tensor.blas')


def debug(*msg):
    """Log the space-joined str() of `msg` at DEBUG level."""
    _logger.debug(' '.join(str(m) for m in msg))


def info(*msg):
    """Log the space-joined str() of `msg` at INFO level."""
    _logger.info(' '.join(str(m) for m in msg))


def warn(*msg):
    """Log the space-joined str() of `msg` at WARNING level.

    Kept as a backward-compatible alias of `warning`.  It deliberately calls
    `Logger.warning` rather than the deprecated `Logger.warn` alias, which
    emits a DeprecationWarning on Python 3 and no longer exists in recent
    Python releases.
    """
    _logger.warning(' '.join(str(m) for m in msg))


def warning(*msg):
    """Log the space-joined str() of `msg` at WARNING level."""
    _logger.warning(' '.join(str(m) for m in msg))


def error(*msg):
    """Log the space-joined str() of `msg` at ERROR level."""
    _logger.error(' '.join(str(m) for m in msg))
@utils.memoize @utils.memoize
def ldflags(libs=True, flags=False): def ldflags(libs=True, flags=False):
"""Return a list of libraries against which an Op's object file should be """Return a list of libraries against which an Op's object file should be
...@@ -655,6 +662,8 @@ def local_dot_to_dot22(node): ...@@ -655,6 +662,8 @@ def local_dot_to_dot22(node):
x,y = node.inputs x,y = node.inputs
if _is_real_matrix(x) and y.type == x.type: if _is_real_matrix(x) and y.type == x.type:
return [_dot22(*node.inputs)] return [_dot22(*node.inputs)]
else:
info('Not optimizing dot with inputs', x, y)
else: else:
return False return False
register_specialize(local_dot_to_dot22) register_specialize(local_dot_to_dot22)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论