提交 a02f01cf authored 作者: James Bergstra's avatar James Bergstra

Added test for many broadcastable patterns with dot

上级 24401f1c
...@@ -1604,6 +1604,54 @@ class t_dot(unittest.TestCase): ...@@ -1604,6 +1604,54 @@ class t_dot(unittest.TestCase):
#utt.verify_grad(dot, [self.rand(), self.rand(2)]) #utt.verify_grad(dot, [self.rand(), self.rand(2)])
#utt.verify_grad(dot, [self.rand(), self.rand(2,5)]) #utt.verify_grad(dot, [self.rand(), self.rand(2,5)])
def test_broadcastable_patterns(self):
#
# These examples hsould all work because we broadcastable or no, all dimensions of all
# results have size 1.
#
def val_for(r):
if r.ndim == 0:
return numpy.asarray(1.1, dtype=r.dtype)
if r.ndim == 1:
return numpy.asarray([1.2], dtype=r.dtype)
elif r.ndim == 2:
return numpy.asarray([[1.3]], dtype=r.dtype)
raise ValueError()
failures = []
for dtype0 in ('float32', 'float64', 'complex64', 'complex128'):
for dtype1 in ('float32', 'float64', 'complex64', 'complex128'):
for bc0 in ((True,), (False,), (True, True), (True, False), (False, True),
(False,False)):
for bc1 in ((True,), (False,), (True, True), (True, False), (False, True),
(False,False)):
x = TensorType(dtype=dtype0, broadcastable=bc0)()
y = TensorType(dtype=dtype1, broadcastable=bc1)()
z = dot(x,y)
t = TensorType(dtype=dtype0, broadcastable=z.broadcastable)()
rval = z * 3 + 2*t
if rval.type.dtype.startswith('complex'):
# there is a problem with complex numbers right now
# Elemwise code doesn't compile when both precisions of complex
# numbers are used in the same file because the operators
# aren't declared properly.
failures.append((dtype0, dtype1, bc0, bc1))
continue
f = function([x,y,t], rval)
xval = val_for(x)
yval = val_for(y)
tval = val_for(t)
f(xval, yval, tval) #debugmode checks result
#if failures:
#print failures
assert not failures
class T_tensorfromscalar(unittest.TestCase): class T_tensorfromscalar(unittest.TestCase):
def test0(self): def test0(self):
s = scal.constant(56) s = scal.constant(56)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论