提交 c7653446 authored 作者: Jesse Livezey's avatar Jesse Livezey

flake

上级 12b63cdc
......@@ -76,32 +76,32 @@ class BaseCorrMM(gof.Op):
str(self.filter_dilation))
def cast(self, in1, in2):
"""
Upcast input variables if neccesary.
"""
float_types = ['float32', 'float64']
dtype_1 = in1.type.dtype
dtype_2 = in2.type.dtype
assert dtype_1 in float_types
assert dtype_2 in float_types
if dtype_1 == 'float64' or dtype_2 == 'float64':
dtype_o = 'float64'
else:
dtype_o = 'float32'
if dtype_1 != dtype_o:
out1 = in1.astype(dtype_o)
else:
out1 = in1
if dtype_2 != dtype_o:
out2 = in2.astype(dtype_o)
else:
out2 = in2
return out1, out2
"""
Upcast input variables if neccesary.
"""
float_types = ['float32', 'float64']
dtype_1 = in1.type.dtype
dtype_2 = in2.type.dtype
assert dtype_1 in float_types
assert dtype_2 in float_types
if dtype_1 == 'float64' or dtype_2 == 'float64':
dtype_o = 'float64'
else:
dtype_o = 'float32'
if dtype_1 != dtype_o:
out1 = in1.astype(dtype_o)
else:
out1 = in1
if dtype_2 != dtype_o:
out2 = in2.astype(dtype_o)
else:
out2 = in2
return out1, out2
def c_support_code(self):
return blas_header_text()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论