提交 628b35b4 authored 作者: James Bergstra's avatar James Bergstra

merge

......@@ -463,7 +463,12 @@ def _is_function_output(node):
def _is_used_in_graph(node):
return not(_is_function_output(node) or node.clients==[])
def _check_strides_match(a, b, raise_on_err, op):
def _check_strides_match(a, b, warn_err, op):
"""
param: warn_err: if 0, no warning, if 1 warning, if 2 error
"""
if warn_err==0: return
try:
strides_eq = a.strides == b.strides
except:
......@@ -471,7 +476,7 @@ def _check_strides_match(a, b, raise_on_err, op):
if not strides_eq:
e = TypeError('Stride mismatch', (a.shape, b.shape, a.strides, b.strides, str(op)))
if raise_on_err:
if warn_err==2:
raise e
else:
print >> sys.stderr, 'WARNING:', e
......@@ -1370,10 +1375,10 @@ class DebugMode(Mode):
Should we check for (and complain about) NaN/Inf ndarray elements?
"""
require_matching_strides = bool(int(os.getenv('THEANO_DEBUGMODE_CHECK_STRIDES', 0)))
require_matching_strides = bool(int(os.getenv('THEANO_DEBUGMODE_CHECK_STRIDES', 1)))
"""
Should we check for (and complain about) Ops whose python and C outputs are ndarrays with
different strides? (This can catch bugs, but is generally overly strict.)
different strides? (This can catch bugs, but is generally overly strict.) 0 no check, 1 warn, 2 err.
"""
# This function will be used to create a FunctionMaker in
......
......@@ -2570,6 +2570,8 @@ def tile(x, reps, ndim=None):
class Dot(Op):
"""Compute matrix-matrix, matrix-vector products and vector inner-products.
:note: matrix-matrix products are sometimes optimized to Dot22 ops (see tensor.blas)
"""
def make_node(self, *inputs):
inputs = map(as_tensor_variable, inputs)
......@@ -2613,6 +2615,8 @@ class Dot(Op):
def perform(self, node, (x, y), (z, )):
try:
# the asarray is here because dot between two vectors gives a numpy float object
# but we need to return a 0d ndarray
z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that
......
......@@ -10,7 +10,6 @@ from elemwise import Elemwise, DimShuffle
from theano import scalar
import basic as T
import inplace as I
import numpy
import numpy as N
import operator
import itertools
......@@ -873,9 +872,9 @@ def local_mul_zero(node):
except TypeError:
continue
#print 'MUL by value', value, node.inputs
if numpy.all(value == 0):
if N.all(value == 0):
#print '... returning zeros'
return _fill_chain(numpy.asarray(0, dtype=otype.dtype), node.inputs)
return _fill_chain(N.asarray(0, dtype=otype.dtype), node.inputs)
register_canonicalize(local_mul_zero)
@gof.local_optimizer([T.true_div])
......@@ -982,7 +981,7 @@ def local_mul_specialize(node):
return False
register_specialize(local_mul_specialize)
@gof.local_optimizer([T.mul])
@gof.local_optimizer([T.add])
def local_add_specialize(node):
def fill_chain(v):
return _fill_chain(v, node.inputs)
......
......@@ -185,14 +185,58 @@ class test_canonize(unittest.TestCase):
"""
verify that the Canonizer merge sequential Elemwise({mul,add})
"""
x, y, z = matrices('xyz')
for g,n in [
(x+y+z,1),
(x*y*z,1),
(x*y*(x+y+z),2),
]:
f = compile.function([x,y,z], g, mode=compile.Mode(optimizer='fast_run'))
assert(len(f.maker.env.toposort())==n)
shp=(5,5)
fx, fy, fz = fmatrices('xyz')
dx, dy, dz = dmatrices('xyz')
fv = fvector('r').dimshuffle('x',0)
fxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fvv = numpy.asarray(numpy.random.rand(shp[0]),dtype='float32').reshape(1,shp[0])
cases = [
(fx+fy,(fx,fy),(fxv,fyv),1,'float32'),
(fx*fy,(fx,fy),(fxv,fyv),1,'float32'),
(fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(dx+dy+dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
(fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(dx*dy*dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
(fx*fy*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
(dx*dy*(dx+dy+dz),(dx,dy,dz),(dxv,dyv,dzv),2,'float64'),
(fx*fy*(fx+fy+dz),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'),#check mixed type add
(dz*fy*(fx+fy),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'),#check mixed type mul
#check with dimshuffle of constant
(fx+fy+fz+2,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(fx*fy*fz*2,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2+fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2*fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2+fx+fy+fz+2,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2*fx*fy*fz*2,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(fx*fy*2*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
(fx*fy*(2+fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
(fx*fy*2*(fx+fy+fz+2),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
#check with broadcast of row
(fx+fy+fz+fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
(fx*fy*fz*fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
(fv+fx+fy+fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
(fv*fx*fy*fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
(fx*fy*fv*(fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
(fx*fy*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
(fx*fy*fv*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
]#[10:11]
# print cases
for id, [g, sym_inputs, val_inputs, expected_out_nb_elemwise, out_dtype] in enumerate(cases):
f = compile.function(list(sym_inputs), g,
#we need the optimisation enabled, debug do this.
mode=compile.mode.predefined_modes['DEBUG_MODE'])
out = f(*val_inputs)
assert(len(f.maker.env.toposort())==expected_out_nb_elemwise)
assert(out_dtype==out.dtype)
def test_mixeddiv():
"""Test that int division is preserved"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论