提交 628b35b4 作者: James Bergstra

merge

...@@ -463,7 +463,12 @@ def _is_function_output(node): ...@@ -463,7 +463,12 @@ def _is_function_output(node):
def _is_used_in_graph(node): def _is_used_in_graph(node):
return not(_is_function_output(node) or node.clients==[]) return not(_is_function_output(node) or node.clients==[])
def _check_strides_match(a, b, raise_on_err, op): def _check_strides_match(a, b, warn_err, op):
"""
param: warn_err: if 0, no warning, if 1 warning, if 2 error
"""
if warn_err==0: return
try: try:
strides_eq = a.strides == b.strides strides_eq = a.strides == b.strides
except: except:
...@@ -471,7 +476,7 @@ def _check_strides_match(a, b, raise_on_err, op): ...@@ -471,7 +476,7 @@ def _check_strides_match(a, b, raise_on_err, op):
if not strides_eq: if not strides_eq:
e = TypeError('Stride mismatch', (a.shape, b.shape, a.strides, b.strides, str(op))) e = TypeError('Stride mismatch', (a.shape, b.shape, a.strides, b.strides, str(op)))
if raise_on_err: if warn_err==2:
raise e raise e
else: else:
print >> sys.stderr, 'WARNING:', e print >> sys.stderr, 'WARNING:', e
...@@ -1370,10 +1375,10 @@ class DebugMode(Mode): ...@@ -1370,10 +1375,10 @@ class DebugMode(Mode):
Should we check for (and complain about) NaN/Inf ndarray elements? Should we check for (and complain about) NaN/Inf ndarray elements?
""" """
require_matching_strides = bool(int(os.getenv('THEANO_DEBUGMODE_CHECK_STRIDES', 0))) require_matching_strides = bool(int(os.getenv('THEANO_DEBUGMODE_CHECK_STRIDES', 1)))
""" """
Should we check for (and complain about) Ops whose python and C outputs are ndarrays with Should we check for (and complain about) Ops whose python and C outputs are ndarrays with
different strides? (This can catch bugs, but is generally overly strict.) different strides? (This can catch bugs, but is generally overly strict.) 0 no check, 1 warn, 2 err.
""" """
# This function will be used to create a FunctionMaker in # This function will be used to create a FunctionMaker in
......
...@@ -2570,6 +2570,8 @@ def tile(x, reps, ndim=None): ...@@ -2570,6 +2570,8 @@ def tile(x, reps, ndim=None):
class Dot(Op): class Dot(Op):
"""Compute matrix-matrix, matrix-vector products and vector inner-products. """Compute matrix-matrix, matrix-vector products and vector inner-products.
:note: matrix-matrix products are sometimes optimized to Dot22 ops (see tensor.blas)
""" """
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = map(as_tensor_variable, inputs) inputs = map(as_tensor_variable, inputs)
...@@ -2613,6 +2615,8 @@ class Dot(Op): ...@@ -2613,6 +2615,8 @@ class Dot(Op):
def perform(self, node, (x, y), (z, )): def perform(self, node, (x, y), (z, )):
try: try:
# the asarray is here because dot between two vectors gives a numpy float object
# but we need to return a 0d ndarray
z[0] = numpy.asarray(numpy.dot(x, y)) z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e: except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that # The error raised by numpy has no shape information, we mean to add that
......
...@@ -10,7 +10,6 @@ from elemwise import Elemwise, DimShuffle ...@@ -10,7 +10,6 @@ from elemwise import Elemwise, DimShuffle
from theano import scalar from theano import scalar
import basic as T import basic as T
import inplace as I import inplace as I
import numpy
import numpy as N import numpy as N
import operator import operator
import itertools import itertools
...@@ -873,9 +872,9 @@ def local_mul_zero(node): ...@@ -873,9 +872,9 @@ def local_mul_zero(node):
except TypeError: except TypeError:
continue continue
#print 'MUL by value', value, node.inputs #print 'MUL by value', value, node.inputs
if numpy.all(value == 0): if N.all(value == 0):
#print '... returning zeros' #print '... returning zeros'
return _fill_chain(numpy.asarray(0, dtype=otype.dtype), node.inputs) return _fill_chain(N.asarray(0, dtype=otype.dtype), node.inputs)
register_canonicalize(local_mul_zero) register_canonicalize(local_mul_zero)
@gof.local_optimizer([T.true_div]) @gof.local_optimizer([T.true_div])
...@@ -982,7 +981,7 @@ def local_mul_specialize(node): ...@@ -982,7 +981,7 @@ def local_mul_specialize(node):
return False return False
register_specialize(local_mul_specialize) register_specialize(local_mul_specialize)
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.add])
def local_add_specialize(node): def local_add_specialize(node):
def fill_chain(v): def fill_chain(v):
return _fill_chain(v, node.inputs) return _fill_chain(v, node.inputs)
......
...@@ -185,14 +185,58 @@ class test_canonize(unittest.TestCase): ...@@ -185,14 +185,58 @@ class test_canonize(unittest.TestCase):
""" """
verify that the Canonizer merge sequential Elemwise({mul,add}) verify that the Canonizer merge sequential Elemwise({mul,add})
""" """
x, y, z = matrices('xyz') shp=(5,5)
for g,n in [ fx, fy, fz = fmatrices('xyz')
(x+y+z,1), dx, dy, dz = dmatrices('xyz')
(x*y*z,1), fv = fvector('r').dimshuffle('x',0)
(x*y*(x+y+z),2), fxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
]: fyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
f = compile.function([x,y,z], g, mode=compile.Mode(optimizer='fast_run')) fzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
assert(len(f.maker.env.toposort())==n) dxv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dyv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
dzv = numpy.asarray(numpy.random.rand(*shp),dtype='float32')
fvv = numpy.asarray(numpy.random.rand(shp[0]),dtype='float32').reshape(1,shp[0])
cases = [
(fx+fy,(fx,fy),(fxv,fyv),1,'float32'),
(fx*fy,(fx,fy),(fxv,fyv),1,'float32'),
(fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(dx+dy+dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
(fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(dx*dy*dz,(dx,dy,dz),(dxv,dyv,dzv),1,'float64'),
(fx*fy*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
(dx*dy*(dx+dy+dz),(dx,dy,dz),(dxv,dyv,dzv),2,'float64'),
(fx*fy*(fx+fy+dz),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'),#check mixed type add
(dz*fy*(fx+fy),(fx,fy,dz),(dxv,dyv,dzv),2,'float64'),#check mixed type mul
#check with dimshuffle of constant
(fx+fy+fz+2,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(fx*fy*fz*2,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2+fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2*fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2+fx+fy+fz+2,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(2*fx*fy*fz*2,(fx,fy,fz),(fxv,fyv,fzv),1,'float32'),
(fx*fy*2*(fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
(fx*fy*(2+fx+fy+fz),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
(fx*fy*2*(fx+fy+fz+2),(fx,fy,fz),(fxv,fyv,fzv),2,'float32'),
#check with broadcast of row
(fx+fy+fz+fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
(fx*fy*fz*fv,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
(fv+fx+fy+fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
(fv*fx*fy*fz,(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),1,'float32'),
(fx*fy*fv*(fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
(fx*fy*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
(fx*fy*fv*(fv+fx+fy+fz),(fx,fy,fz,fv),(fxv,fyv,fzv,fvv),2,'float32'),
]#[10:11]
# print cases
for id, [g, sym_inputs, val_inputs, expected_out_nb_elemwise, out_dtype] in enumerate(cases):
f = compile.function(list(sym_inputs), g,
#we need the optimisation enabled, debug do this.
mode=compile.mode.predefined_modes['DEBUG_MODE'])
out = f(*val_inputs)
assert(len(f.maker.env.toposort())==expected_out_nb_elemwise)
assert(out_dtype==out.dtype)
def test_mixeddiv(): def test_mixeddiv():
"""Test that int division is preserved""" """Test that int division is preserved"""
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论