提交 fe306e3f authored 作者: --global's avatar --global

Update tests in tensor/tests/test_blas.py

上级 e01b285e
...@@ -875,7 +875,7 @@ def test_dot22scalar(): ...@@ -875,7 +875,7 @@ def test_dot22scalar():
cst = theano.tensor.basic.constant(.2, dtype=dtype4) cst = theano.tensor.basic.constant(.2, dtype=dtype4)
cst2 = theano.tensor.basic.constant(.1, dtype=dtype4) cst2 = theano.tensor.basic.constant(.1, dtype=dtype4)
def check_dot22scalar_gemm(func, len_topo_scalar=-1): def check_dot22scalar(func, len_topo_scalar=-1):
topo = func.maker.fgraph.toposort() topo = func.maker.fgraph.toposort()
ops = [x.op for x in topo] ops = [x.op for x in topo]
classes = [type(x.op) for x in topo] classes = [type(x.op) for x in topo]
...@@ -885,22 +885,20 @@ def test_dot22scalar(): ...@@ -885,22 +885,20 @@ def test_dot22scalar():
if dtype1 == dtype2 == dtype3 == dtype4_upcast: if dtype1 == dtype2 == dtype3 == dtype4_upcast:
if len_topo_scalar > 0: if len_topo_scalar > 0:
assert len(topo) == len_topo_scalar assert len(topo) == len_topo_scalar
assert gemm_inplace in ops, (dtype1, dtype2, assert _dot22scalar in ops, (dtype1, dtype2,
dtype3, dtype4) dtype3, dtype4)
elif dtype1 == dtype2 == dtype4_upcast: elif dtype1 == dtype2 == dtype4_upcast:
if not (len_topo_scalar > 0): if not (len_topo_scalar > 0):
assert len(topo) == len_topo_scalar assert len(topo) == len_topo_scalar
assert gemm_inplace in ops, (dtype1, dtype2, assert _dot22scalar in ops, (dtype1, dtype2,
dtype3, dtype4) dtype3, dtype4)
assert not T.Elemwise in classes, (
dtype1, dtype2, dtype3, dtype4)
else: else:
# Currently there is a problem of # Currently there is a problem of
# optimization order The constant get # optimization order The constant get
# upcasted to float64 before we try to # upcasted to float64 before we try to
# merge it with the dot22 of # merge it with the dot22 of
# float32. So this prevent the merge. # float32. So this prevent the merge.
assert gemm_inplace in ops or _dot22 in ops, ( assert _dot22scalar in ops or _dot22 in ops, (
dtype1, dtype2, dtype3, dtype4) dtype1, dtype2, dtype3, dtype4)
elif dtype1 == dtype2: elif dtype1 == dtype2:
...@@ -920,7 +918,7 @@ def test_dot22scalar(): ...@@ -920,7 +918,7 @@ def test_dot22scalar():
f = theano.function([a, b], cst * T.dot(a, b), f = theano.function([a, b], cst * T.dot(a, b),
mode=mode_blas_opt) mode=mode_blas_opt)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 1) check_dot22scalar(f, 1)
f(av, bv) f(av, bv)
...@@ -929,8 +927,7 @@ def test_dot22scalar(): ...@@ -929,8 +927,7 @@ def test_dot22scalar():
cst * c * T.dot(a, b), cst * c * T.dot(a, b),
mode=mode_blas_opt) mode=mode_blas_opt)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5) check_dot22scalar(f, 2)
#print (av.dtype, bv.dtype, cv.dtype)
f(av, bv, cv) f(av, bv, cv)
...@@ -938,7 +935,7 @@ def test_dot22scalar(): ...@@ -938,7 +935,7 @@ def test_dot22scalar():
c * cst * T.dot(a, b), c * cst * T.dot(a, b),
mode=mode_blas_opt) mode=mode_blas_opt)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5) check_dot22scalar(f, 2)
f(av, bv, cv) f(av, bv, cv)
# Here, canonicalize also seems needed # Here, canonicalize also seems needed
...@@ -948,7 +945,7 @@ def test_dot22scalar(): ...@@ -948,7 +945,7 @@ def test_dot22scalar():
cst2 * c * cst * T.dot(a, b), cst2 * c * cst * T.dot(a, b),
mode=m2) mode=m2)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5) check_dot22scalar(f, 2)
f(av, bv, cv) f(av, bv, cv)
if dtype1 == dtype2 == dtype3: if dtype1 == dtype2 == dtype3:
...@@ -956,7 +953,7 @@ def test_dot22scalar(): ...@@ -956,7 +953,7 @@ def test_dot22scalar():
c * cst * a * T.dot(a, b), c * cst * a * T.dot(a, b),
mode=m2) mode=m2)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5) check_dot22scalar(f, 2)
f(sv, sv, sv) f(sv, sv, sv)
f = theano.function([a, b, c], f = theano.function([a, b, c],
...@@ -979,7 +976,7 @@ def test_dot22scalar(): ...@@ -979,7 +976,7 @@ def test_dot22scalar():
c * a * cst * T.dot(a, b), c * a * cst * T.dot(a, b),
mode=m2) mode=m2)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5) check_dot22scalar(f, 2)
f(sv, sv, sv) f(sv, sv, sv)
cmp((3, 4), (4, 5), (3, 5)) cmp((3, 4), (4, 5), (3, 5))
...@@ -999,7 +996,7 @@ def test_dot22scalar_cast(): ...@@ -999,7 +996,7 @@ def test_dot22scalar_cast():
for scalar_int_type in T.int_dtypes: for scalar_int_type in T.int_dtypes:
y = T.scalar(dtype=scalar_int_type) y = T.scalar(dtype=scalar_int_type)
f = theano.function([A, y], T.dot(A, A) * y, mode=mode_blas_opt) f = theano.function([A, y], T.dot(A, A) * y, mode=mode_blas_opt)
assert gemm_inplace in [x.op for x in f.maker.fgraph.toposort()] assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()]
A = T.fmatrix() A = T.fmatrix()
for scalar_int_type in T.int_dtypes: for scalar_int_type in T.int_dtypes:
y = T.scalar(dtype=scalar_int_type) y = T.scalar(dtype=scalar_int_type)
...@@ -1007,7 +1004,7 @@ def test_dot22scalar_cast(): ...@@ -1007,7 +1004,7 @@ def test_dot22scalar_cast():
if scalar_int_type in ['int32', 'int64']: if scalar_int_type in ['int32', 'int64']:
assert _dot22 in [x.op for x in f.maker.fgraph.toposort()] assert _dot22 in [x.op for x in f.maker.fgraph.toposort()]
else: else:
assert gemm_inplace in [x.op for x in f.maker.fgraph.toposort()] assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()]
def test_local_dot22_to_dot22scalar(): def test_local_dot22_to_dot22scalar():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论