提交 df4e0d80 authored 作者: Frederic's avatar Frederic

Readd the previous check with is_pair as I'm not able to add a test for it.

上级 92021ca6
...@@ -1113,11 +1113,23 @@ def _gemm_from_factored_list(lst): ...@@ -1113,11 +1113,23 @@ def _gemm_from_factored_list(lst):
"""Returns None, or a list to replace node.outputs """Returns None, or a list to replace node.outputs
""" """
# Make every pair in list have matching dtypes
# sM can be a tuple of 2 elements of a theano variable
# We should not use __len__ as the theano variable don't support
# it. I don't want to change this to ininstance(sM, tuple)
# as I'm not able to make a test that triger
def is_pair(sM):
try:
s, M = sM
return True
except Exception:
return False
lst2 = [] lst2 = []
# Remove the tuple that can't be casted correctly. # Remove the tuple that can't be casted correctly.
# This can happen when we try to cast a complex to a real # This can happen when we try to cast a complex to a real
for sM in lst: for sM in lst:
if len(sM) == 2: if is_pair(sM):
sm0, sm1 = sM sm0, sm1 = sM
sm0 = T.as_tensor_variable(sm0) sm0 = T.as_tensor_variable(sm0)
if theano.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype: if theano.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论