提交 aa278955 authored 作者: James Bergstra's avatar James Bergstra

test_elemwise1 passes, didnt run in debugmode

上级 56d152c8
......@@ -4,3 +4,6 @@ from .var import (CudaNdarrayVariable,
CudaNdarrayConstant,
CudaNdarraySharedVariable,
shared_constructor)
import basic_ops
import opt
差异被折叠。
import sys
from theano.compile.sandbox.sharedvalue import shared
from theano.compile.sandbox.pfunc import pfunc
from theano import tensor
......@@ -11,7 +12,7 @@ def test_elemwise0():
a = tcn.shared_constructor(numpy.random.rand(4,4), 'a')
b = tensor.dmatrix()
b = tensor.fmatrix()
f = pfunc([b], [], updates=[(a, a+b)])
......@@ -27,11 +28,27 @@ def test_elemwise1():
""" Several kinds of elemwise expressions with no broadcasting, non power-of-two shape """
shape = (3,4)
a = tcn.shared_constructor(numpy.random.rand(*shape), 'a')
b = tensor.dmatrix()
f = pfunc([b], [], updates=[(a, a+b * tensor.exp(b**a))])
a = tcn.shared_constructor(numpy.random.rand(*shape)+0.5, 'a')
b = tensor.fmatrix()
#let debugmode catch any mistakes
f(numpy.ones(shape))
print >> sys.stderr, "STARTING FUNCTION 1"
f = pfunc([b], [], updates=[(a, b**a)])
for i, node in enumerate(f.maker.env.toposort()):
print i, node
f(numpy.random.rand(*shape)+0.3)
print >> sys.stderr, "STARTING FUNCTION 2"
#let debugmode catch any mistakes
f = pfunc([b], [], updates=[(a, tensor.exp(b**a))])
for i, node in enumerate(f.maker.env.toposort()):
print i, node
f(numpy.random.rand(*shape)+0.3)
print >> sys.stderr, "STARTING FUNCTION 3"
#let debugmode catch any mistakes
f = pfunc([b], [], updates=[(a, a+b * tensor.exp(b**a))])
f(numpy.random.rand(*shape)+0.3)
def test_elemwise2():
""" Several kinds of elemwise expressions with dimension permutations """
......@@ -41,6 +58,11 @@ def test_elemwise2():
b = tensor.Tensor(dtype='float32', broadcastable=[0]*len(shape))()
f = pfunc([b], [], updates=[(a, (a+b).dimshuffle([2,0,3,1]) *
tensor.exp(b**a).dimshuffle([2,0,3,1]))])
has_elemwise = False
for i, node in enumerate(f.maker.env.toposort()):
print i, node
has_elemwise = has_elemwise or isinstance(node.op, tensor.Elemwise)
assert not has_elemwise
#let debugmode catch errors
f(numpy.ones(shape))
......@@ -54,3 +76,4 @@ def test_elemwise3():
b**a).dimshuffle([2,0,3,1]))])
#let debugmode catch errors
f(numpy.ones(6))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论