Commit f4f68658, authored by Frederic

gemm opt speed-up: small refactoring to call fewer functions.

Parent: de33e6ad
@@ -1379,27 +1379,25 @@ def _gemm_from_factored_list(lst):
     """Returns None, or a list to replace node.outputs
     """
-    # Make every pair in list have matching dtypes
-    # sM can be a tuple of 2 elements or a theano variable.
-    # We should not use __len__ as theano variables don't support
-    # it. I don't want to change this to isinstance(sM, tuple)
-    # as I'm not able to make a test that triggers this case.
-    def is_pair(sM):
-        try:
-            s, M = sM
-            return True
-        except Exception:
-            return False
     lst2 = []
     # Remove the tuple that can't be cast correctly.
     # This can happen when we try to cast a complex to a real
     for sM in lst:
-        if is_pair(sM):
+        # Make every pair in list have matching dtypes
+        # sM can be a tuple of 2 elements or a theano variable.
+        # We should not use __len__ as theano variables don't support
+        # it. I don't want to change this to isinstance(sM, tuple)
+        # as I'm not able to make a test that triggers this case.
+        try:
             sm0, sm1 = sM
             sm0 = T.as_tensor_variable(sm0)
             if theano.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype:
                 lst2.append((T.cast(sm0, sm1.dtype), sM[1]))
+        except ValueError:
+            # "ValueError: length not known" is raised by
+            # "sm0, sm1 = sM" when sM is a Theano variable
+            pass
     lst = lst2
     # Try every pair in the sM_list, trying to turn it into a gemm operation
......
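The refactoring in this diff replaces a helper that test-unpacks each item with a single inline `try`/`except ValueError` around the unpacking. As a rough illustration only, the pattern can be sketched in plain Python without Theano. `FakeSymbolic` and `keep_pairs` are hypothetical names invented for this sketch; `FakeSymbolic` merely imitates an object whose unpacking raises `ValueError("length not known")`, as the comment in the commit describes for Theano variables.

```python
class FakeSymbolic:
    """Hypothetical stand-in for a variable whose length is unknown.

    This is NOT a Theano class; it only mimics the failure mode that
    the commit's comment describes.
    """

    def __iter__(self):
        raise ValueError("length not known")


def keep_pairs(items):
    """Return only the items that unpack into exactly two elements."""
    pairs = []
    for item in items:
        try:
            # Unpack once; no separate helper call to probe the item first.
            first, second = item
        except ValueError:
            # Raised for wrong-length iterables and for FakeSymbolic.
            continue
        pairs.append((first, second))
    return pairs


items = [(1, "a"), FakeSymbolic(), (2, "b"), (1, 2, 3)]
result = keep_pairs(items)
```

One difference from the diff is deliberate: the sketch keeps only the unpacking inside the `try` block. In the committed version the `try` also encloses the `T.as_tensor_variable`, `theano.scalar.upcast` and `T.cast` calls, so a `ValueError` raised by any of them would be silently skipped as well.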