提交 a4e182df authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2970 from carriepl/scan_memory_usage

[Bug] Revert dot22->gemm optimization back to dot22->dot22scalar
......@@ -570,8 +570,6 @@ def local_gpu_dot22(node):
@local_optimizer([gpu_from_host, tensor.blas.Dot22Scalar])
def local_gpu_dot22scalar(node):
"""
Deprecated : _dot22scalar has been replaced by gemm;
see Dot22Scalar for more details
gpu_from_host(dot22scalar) -> gpudot(gpu_from_host)
dot(host_from_gpu) -> host_from_gpu(gpudot22scalar)
......
......@@ -1984,6 +1984,13 @@ _dot22scalar = Dot22Scalar()
@local_optimizer([T.mul])
def local_dot22_to_dot22scalar(node):
"""
:note: Previous attempts to alter this optimization to replace dot22 with
gemm instead of dot22scalar resulted in some Scan nodes being
duplicated and the ScanSaveMem optimization never running on them,
resulting in highly increased memory usage. Until this issue is
resolved, this optimization should keep using dot22scalar instead of
gemm.
:note: we upcast the scalar if, after the multiplication with the
dot, this gives the same type.
......@@ -2043,11 +2050,7 @@ def local_dot22_to_dot22scalar(node):
a = T.cast(_as_scalar(m.owner.inputs[scalar_idx],
dtype=d.dtype), d.type.dtype)
assert not a.type.ndim
z = T.AllocEmpty(d.owner.inputs[0].dtype)(d.owner.inputs[0].shape[0],
d.owner.inputs[1].shape[1])
zero = T.as_tensor_variable(numpy.asarray(0, dtype=a.dtype))
dot = gemm(z, a, d.owner.inputs[0], d.owner.inputs[1], zero)
dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
# The other inputs to the original node that were
# neither part of the dot22 or this mul should be
......@@ -2083,16 +2086,12 @@ def local_dot22_to_dot22scalar(node):
a = T.cast(i_scalar[scalar_idx], d.type.dtype)
assert not a.type.ndim
if len(o) == 0:
z = T.AllocEmpty(d.owner.inputs[0].dtype)(d.owner.inputs[0].shape[0],
d.owner.inputs[1].shape[1])
zero = T.as_tensor_variable(numpy.asarray(0, dtype=a.dtype))
return [gemm(z, a, d.owner.inputs[0], d.owner.inputs[1], zero)]
return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)]
else:
z = T.AllocEmpty(d.owner.inputs[0].dtype)(d.owner.inputs[0].shape[0],
d.owner.inputs[1].shape[1])
zero = T.as_tensor_variable(numpy.asarray(0, dtype=a.dtype))
return [T.mul(gemm(z, a, d.owner.inputs[0], d.owner.inputs[1],
zero), *o)]
return [T.mul(_dot22scalar(d.owner.inputs[0],
d.owner.inputs[1], a), *o)]
# must happen after gemm, as the gemm optimizer doesn't understand
# dot22scalar, and gemm gives more speed-up than dot22scalar
blas_optdb.register('local_dot22_to_dot22scalar',
......
......@@ -875,7 +875,7 @@ def test_dot22scalar():
cst = theano.tensor.basic.constant(.2, dtype=dtype4)
cst2 = theano.tensor.basic.constant(.1, dtype=dtype4)
def check_dot22scalar_gemm(func, len_topo_scalar=-1):
def check_dot22scalar(func, len_topo_scalar=-1):
topo = func.maker.fgraph.toposort()
ops = [x.op for x in topo]
classes = [type(x.op) for x in topo]
......@@ -885,22 +885,20 @@ def test_dot22scalar():
if dtype1 == dtype2 == dtype3 == dtype4_upcast:
if len_topo_scalar > 0:
assert len(topo) == len_topo_scalar
assert gemm_inplace in ops, (dtype1, dtype2,
assert _dot22scalar in ops, (dtype1, dtype2,
dtype3, dtype4)
elif dtype1 == dtype2 == dtype4_upcast:
if not (len_topo_scalar > 0):
assert len(topo) == len_topo_scalar
assert gemm_inplace in ops, (dtype1, dtype2,
assert _dot22scalar in ops, (dtype1, dtype2,
dtype3, dtype4)
assert not T.Elemwise in classes, (
dtype1, dtype2, dtype3, dtype4)
else:
# Currently there is a problem of
# optimization order The constant get
# upcasted to float64 before we try to
# merge it with the dot22 of
# float32. So this prevent the merge.
assert gemm_inplace in ops or _dot22 in ops, (
assert _dot22scalar in ops or _dot22 in ops, (
dtype1, dtype2, dtype3, dtype4)
elif dtype1 == dtype2:
......@@ -920,7 +918,7 @@ def test_dot22scalar():
f = theano.function([a, b], cst * T.dot(a, b),
mode=mode_blas_opt)
topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 1)
check_dot22scalar(f, 1)
f(av, bv)
......@@ -929,8 +927,7 @@ def test_dot22scalar():
cst * c * T.dot(a, b),
mode=mode_blas_opt)
topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5)
#print (av.dtype, bv.dtype, cv.dtype)
check_dot22scalar(f, 2)
f(av, bv, cv)
......@@ -938,7 +935,7 @@ def test_dot22scalar():
c * cst * T.dot(a, b),
mode=mode_blas_opt)
topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5)
check_dot22scalar(f, 2)
f(av, bv, cv)
# Here, canonicalize also seems needed
......@@ -948,7 +945,7 @@ def test_dot22scalar():
cst2 * c * cst * T.dot(a, b),
mode=m2)
topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5)
check_dot22scalar(f, 2)
f(av, bv, cv)
if dtype1 == dtype2 == dtype3:
......@@ -956,7 +953,7 @@ def test_dot22scalar():
c * cst * a * T.dot(a, b),
mode=m2)
topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5)
check_dot22scalar(f, 2)
f(sv, sv, sv)
f = theano.function([a, b, c],
......@@ -979,7 +976,7 @@ def test_dot22scalar():
c * a * cst * T.dot(a, b),
mode=m2)
topo = f.maker.fgraph.toposort()
check_dot22scalar_gemm(f, 5)
check_dot22scalar(f, 2)
f(sv, sv, sv)
cmp((3, 4), (4, 5), (3, 5))
......@@ -999,7 +996,7 @@ def test_dot22scalar_cast():
for scalar_int_type in T.int_dtypes:
y = T.scalar(dtype=scalar_int_type)
f = theano.function([A, y], T.dot(A, A) * y, mode=mode_blas_opt)
assert gemm_inplace in [x.op for x in f.maker.fgraph.toposort()]
assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()]
A = T.fmatrix()
for scalar_int_type in T.int_dtypes:
y = T.scalar(dtype=scalar_int_type)
......@@ -1007,7 +1004,7 @@ def test_dot22scalar_cast():
if scalar_int_type in ['int32', 'int64']:
assert _dot22 in [x.op for x in f.maker.fgraph.toposort()]
else:
assert gemm_inplace in [x.op for x in f.maker.fgraph.toposort()]
assert _dot22scalar in [x.op for x in f.maker.fgraph.toposort()]
def test_local_dot22_to_dot22scalar():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论