提交 dd6bc255 authored 作者: Frederic's avatar Frederic 提交者: Melanie Ducoffe

Update to tests

上级 3f755f40
......@@ -5506,7 +5506,7 @@ class AllocEmpty(gof.Op):
sh = tuple([int(i) for i in inputs])
if out[0] is None or out[0].shape != sh:
# XXX: We could implement and call CudaNdarray.empty(sh) instead.
out[0] = numpy.empty(sh)
out[0] = numpy.empty(sh, dtype=self.dtype)
def do_merge(self, node):
return False
......
......@@ -878,25 +878,28 @@ def test_dot22scalar():
def check_dot22scalar(func, len_topo_scalar=-1):
topo = func.maker.fgraph.toposort()
ops = [x.op for x in topo]
classes = [type(x.op) for x in topo]
dtype4_upcast = theano.scalar.upcast(dtype4, dtype1,
dtype2)
if dtype1 == dtype2 == dtype3 == dtype4_upcast:
if len_topo_scalar > 0:
assert len(topo) == len_topo_scalar
assert _dot22scalar in ops, (dtype1, dtype2,
assert gemm_inplace in ops, (dtype1, dtype2,
dtype3, dtype4)
elif dtype1 == dtype2 == dtype4_upcast:
if not (len_topo_scalar > 0):
assert len(topo) == len_topo_scalar
assert _dot22scalar in ops, (dtype1, dtype2,
assert gemm_inplace in ops, (dtype1, dtype2,
dtype3, dtype4)
assert not T.Elemwise in classes, (
dtype1, dtype2, dtype3, dtype4)
else:
# Currently there is a problem of
# optimization order The constant get
# upcasted to float64 before we try to
# merge it with the dot22 of
# float32. So this prevent the merge.
assert _dot22scalar in ops or _dot22 in ops, (
assert gemm_inplace in ops or _dot22 in ops, (
dtype1, dtype2, dtype3, dtype4)
elif dtype1 == dtype2:
......@@ -925,7 +928,7 @@ def test_dot22scalar():
cst * c * T.dot(a, b),
mode=mode_blas_opt)
topo = f.maker.fgraph.toposort()
check_dot22scalar(f, 2)
check_dot22scalar(f, 5)
f(av, bv, cv)
......@@ -933,7 +936,7 @@ def test_dot22scalar():
c * cst * T.dot(a, b),
mode=mode_blas_opt)
topo = f.maker.fgraph.toposort()
check_dot22scalar(f, 2)
check_dot22scalar(f, 5)
f(av, bv, cv)
# Here, canonicalize also seems needed
......@@ -943,7 +946,7 @@ def test_dot22scalar():
cst2 * c * cst * T.dot(a, b),
mode=m2)
topo = f.maker.fgraph.toposort()
check_dot22scalar(f, 2)
check_dot22scalar(f, 5)
f(av, bv, cv)
if dtype1 == dtype2 == dtype3:
......@@ -951,7 +954,7 @@ def test_dot22scalar():
c * cst * a * T.dot(a, b),
mode=m2)
topo = f.maker.fgraph.toposort()
check_dot22scalar(f, 2)
check_dot22scalar(f, 5)
f(sv, sv, sv)
f = theano.function([a, b, c],
......@@ -974,7 +977,7 @@ def test_dot22scalar():
c * a * cst * T.dot(a, b),
mode=m2)
topo = f.maker.fgraph.toposort()
check_dot22scalar(f, 2)
check_dot22scalar(f, 5)
f(sv, sv, sv)
cmp((3, 4), (4, 5), (3, 5))
......@@ -994,7 +997,7 @@ def test_dot22scalar_cast():
for scalar_int_type in T.int_dtypes:
y = T.scalar(dtype=scalar_int_type)
f = theano.function([A, y], T.dot(A, A) * y, mode=mode_blas_opt)
assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()]
assert gemm_inplace in [x.op for x in f.maker.fgraph.toposort()]
A = T.fmatrix()
for scalar_int_type in T.int_dtypes:
y = T.scalar(dtype=scalar_int_type)
......@@ -1002,7 +1005,7 @@ def test_dot22scalar_cast():
if scalar_int_type in ['int32', 'int64']:
assert _dot22 in [x.op for x in f.maker.fgraph.toposort()]
else:
assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()]
assert gemm_inplace in [x.op for x in f.maker.fgraph.toposort()]
def test_local_dot22_to_dot22scalar():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论