提交 dd6bc255 authored 作者: Frederic's avatar Frederic 提交者: Melanie Ducoffe

Update to tests

上级 3f755f40
...@@ -5506,7 +5506,7 @@ class AllocEmpty(gof.Op): ...@@ -5506,7 +5506,7 @@ class AllocEmpty(gof.Op):
sh = tuple([int(i) for i in inputs]) sh = tuple([int(i) for i in inputs])
if out[0] is None or out[0].shape != sh: if out[0] is None or out[0].shape != sh:
# XXX: We could implement and call CudaNdarray.empty(sh) instead. # XXX: We could implement and call CudaNdarray.empty(sh) instead.
out[0] = numpy.empty(sh) out[0] = numpy.empty(sh, dtype=self.dtype)
def do_merge(self, node): def do_merge(self, node):
return False return False
......
...@@ -878,25 +878,28 @@ def test_dot22scalar(): ...@@ -878,25 +878,28 @@ def test_dot22scalar():
def check_dot22scalar(func, len_topo_scalar=-1): def check_dot22scalar(func, len_topo_scalar=-1):
topo = func.maker.fgraph.toposort() topo = func.maker.fgraph.toposort()
ops = [x.op for x in topo] ops = [x.op for x in topo]
classes = [type(x.op) for x in topo]
dtype4_upcast = theano.scalar.upcast(dtype4, dtype1, dtype4_upcast = theano.scalar.upcast(dtype4, dtype1,
dtype2) dtype2)
if dtype1 == dtype2 == dtype3 == dtype4_upcast: if dtype1 == dtype2 == dtype3 == dtype4_upcast:
if len_topo_scalar > 0: if len_topo_scalar > 0:
assert len(topo) == len_topo_scalar assert len(topo) == len_topo_scalar
assert _dot22scalar in ops, (dtype1, dtype2, assert gemm_inplace in ops, (dtype1, dtype2,
dtype3, dtype4) dtype3, dtype4)
elif dtype1 == dtype2 == dtype4_upcast: elif dtype1 == dtype2 == dtype4_upcast:
if not (len_topo_scalar > 0): if not (len_topo_scalar > 0):
assert len(topo) == len_topo_scalar assert len(topo) == len_topo_scalar
assert _dot22scalar in ops, (dtype1, dtype2, assert gemm_inplace in ops, (dtype1, dtype2,
dtype3, dtype4) dtype3, dtype4)
assert not T.Elemwise in classes, (
dtype1, dtype2, dtype3, dtype4)
else: else:
# Currently there is a problem of # Currently there is a problem of
# optimization order The constant get # optimization order The constant get
# upcasted to float64 before we try to # upcasted to float64 before we try to
# merge it with the dot22 of # merge it with the dot22 of
# float32. So this prevent the merge. # float32. So this prevent the merge.
assert _dot22scalar in ops or _dot22 in ops, ( assert gemm_inplace in ops or _dot22 in ops, (
dtype1, dtype2, dtype3, dtype4) dtype1, dtype2, dtype3, dtype4)
elif dtype1 == dtype2: elif dtype1 == dtype2:
...@@ -925,7 +928,7 @@ def test_dot22scalar(): ...@@ -925,7 +928,7 @@ def test_dot22scalar():
cst * c * T.dot(a, b), cst * c * T.dot(a, b),
mode=mode_blas_opt) mode=mode_blas_opt)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 5)
f(av, bv, cv) f(av, bv, cv)
...@@ -933,7 +936,7 @@ def test_dot22scalar(): ...@@ -933,7 +936,7 @@ def test_dot22scalar():
c * cst * T.dot(a, b), c * cst * T.dot(a, b),
mode=mode_blas_opt) mode=mode_blas_opt)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 5)
f(av, bv, cv) f(av, bv, cv)
# Here, canonicalize also seems needed # Here, canonicalize also seems needed
...@@ -943,7 +946,7 @@ def test_dot22scalar(): ...@@ -943,7 +946,7 @@ def test_dot22scalar():
cst2 * c * cst * T.dot(a, b), cst2 * c * cst * T.dot(a, b),
mode=m2) mode=m2)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 5)
f(av, bv, cv) f(av, bv, cv)
if dtype1 == dtype2 == dtype3: if dtype1 == dtype2 == dtype3:
...@@ -951,7 +954,7 @@ def test_dot22scalar(): ...@@ -951,7 +954,7 @@ def test_dot22scalar():
c * cst * a * T.dot(a, b), c * cst * a * T.dot(a, b),
mode=m2) mode=m2)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 5)
f(sv, sv, sv) f(sv, sv, sv)
f = theano.function([a, b, c], f = theano.function([a, b, c],
...@@ -974,7 +977,7 @@ def test_dot22scalar(): ...@@ -974,7 +977,7 @@ def test_dot22scalar():
c * a * cst * T.dot(a, b), c * a * cst * T.dot(a, b),
mode=m2) mode=m2)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar(f, 2) check_dot22scalar(f, 5)
f(sv, sv, sv) f(sv, sv, sv)
cmp((3, 4), (4, 5), (3, 5)) cmp((3, 4), (4, 5), (3, 5))
...@@ -994,7 +997,7 @@ def test_dot22scalar_cast(): ...@@ -994,7 +997,7 @@ def test_dot22scalar_cast():
for scalar_int_type in T.int_dtypes: for scalar_int_type in T.int_dtypes:
y = T.scalar(dtype=scalar_int_type) y = T.scalar(dtype=scalar_int_type)
f = theano.function([A, y], T.dot(A, A) * y, mode=mode_blas_opt) f = theano.function([A, y], T.dot(A, A) * y, mode=mode_blas_opt)
assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()] assert gemm_inplace in [x.op for x in f.maker.fgraph.toposort()]
A = T.fmatrix() A = T.fmatrix()
for scalar_int_type in T.int_dtypes: for scalar_int_type in T.int_dtypes:
y = T.scalar(dtype=scalar_int_type) y = T.scalar(dtype=scalar_int_type)
...@@ -1002,7 +1005,7 @@ def test_dot22scalar_cast(): ...@@ -1002,7 +1005,7 @@ def test_dot22scalar_cast():
if scalar_int_type in ['int32', 'int64']: if scalar_int_type in ['int32', 'int64']:
assert _dot22 in [x.op for x in f.maker.fgraph.toposort()] assert _dot22 in [x.op for x in f.maker.fgraph.toposort()]
else: else:
assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()] assert gemm_inplace in [x.op for x in f.maker.fgraph.toposort()]
def test_local_dot22_to_dot22scalar(): def test_local_dot22_to_dot22scalar():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论