提交 fe306e3f authored 作者: --global's avatar --global

Update tests in tensor/tests/test_blas.py

上级 e01b285e
......@@ -875,7 +875,7 @@ def test_dot22scalar():
cst = theano.tensor.basic.constant(.2, dtype=dtype4)
cst2 = theano.tensor.basic.constant(.1, dtype=dtype4)
def check_dot22scalar_gemm(func, len_topo_scalar=-1):
def check_dot22scalar(func, len_topo_scalar=-1):
topo = func.maker.fgraph.toposort()
ops = [x.op for x in topo]
classes = [type(x.op) for x in topo]
......@@ -885,22 +885,20 @@ def test_dot22scalar():
if dtype1 == dtype2 == dtype3 == dtype4_upcast:
if len_topo_scalar > 0:
assert len(topo) == len_topo_scalar
assert gemm_inplace in ops, (dtype1, dtype2,
assert _dot22scalar in ops, (dtype1, dtype2,
dtype3, dtype4)
elif dtype1 == dtype2 == dtype4_upcast:
if not (len_topo_scalar > 0):
assert len(topo) == len_topo_scalar
assert gemm_inplace in ops, (dtype1, dtype2,
assert _dot22scalar in ops, (dtype1, dtype2,
dtype3, dtype4)
assert not T.Elemwise in classes, (
dtype1, dtype2, dtype3, dtype4)
else:
# Currently there is a problem of
# optimization order The constant get
# upcasted to float64 before we try to
# merge it with the dot22 of
# float32. So this prevent the merge.
assert gemm_inplace in ops or _dot22 in ops, (
assert _dot22scalar in ops or _dot22 in ops, (
dtype1, dtype2, dtype3, dtype4)
elif dtype1 == dtype2:
......@@ -920,7 +918,7 @@ def test_dot22scalar():
f = theano.function([a, b], cst * T.dot(a, b),
mode=mode_blas_opt)
topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 1)
check_dot22scalar(f, 1)
f(av, bv)
......@@ -929,8 +927,7 @@ def test_dot22scalar():
cst * c * T.dot(a, b),
mode=mode_blas_opt)
topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5)
#print (av.dtype, bv.dtype, cv.dtype)
check_dot22scalar(f, 2)
f(av, bv, cv)
......@@ -938,7 +935,7 @@ def test_dot22scalar():
c * cst * T.dot(a, b),
mode=mode_blas_opt)
topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5)
check_dot22scalar(f, 2)
f(av, bv, cv)
# Here, canonicalize also seems needed
......@@ -948,7 +945,7 @@ def test_dot22scalar():
cst2 * c * cst * T.dot(a, b),
mode=m2)
topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5)
check_dot22scalar(f, 2)
f(av, bv, cv)
if dtype1 == dtype2 == dtype3:
......@@ -956,7 +953,7 @@ def test_dot22scalar():
c * cst * a * T.dot(a, b),
mode=m2)
topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5)
check_dot22scalar(f, 2)
f(sv, sv, sv)
f = theano.function([a, b, c],
......@@ -979,7 +976,7 @@ def test_dot22scalar():
c * a * cst * T.dot(a, b),
mode=m2)
topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5)
check_dot22scalar(f, 2)
f(sv, sv, sv)
cmp((3, 4), (4, 5), (3, 5))
......@@ -999,7 +996,7 @@ def test_dot22scalar_cast():
for scalar_int_type in T.int_dtypes:
y = T.scalar(dtype=scalar_int_type)
f = theano.function([A, y], T.dot(A, A) * y, mode=mode_blas_opt)
assert gemm_inplace in [x.op for x in f.maker.fgraph.toposort()]
assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()]
A = T.fmatrix()
for scalar_int_type in T.int_dtypes:
y = T.scalar(dtype=scalar_int_type)
......@@ -1007,7 +1004,7 @@ def test_dot22scalar_cast():
if scalar_int_type in ['int32', 'int64']:
assert _dot22 in [x.op for x in f.maker.fgraph.toposort()]
else:
assert gemm_inplace in [x.op for x in f.maker.fgraph.toposort()]
assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()]
def test_local_dot22_to_dot22scalar():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论