提交 a6165acc authored 作者: João Victor Risso's avatar João Victor Risso

Add check_input attribute to spatial transformer ops

上级 4089d4d9
......@@ -2798,6 +2798,7 @@ class GpuDnnTransformerGrid(DnnBase):
_cop_num_inputs = 3
_cop_num_outputs = 1
_f16_ok = True
check_input = False
def __init__(self):
DnnBase.__init__(self, ["c_code/dnn_sptf_grid.c"], "APPLY_SPECIFIC(dnn_sptf_grid)")
......@@ -2844,6 +2845,7 @@ class GpuDnnTransformerSampler(DnnBase):
_cop_num_inputs = 3
_cop_num_outputs = 1
_f16_ok = True
check_input = False
def __init__(self):
DnnBase.__init__(self, ["c_code/dnn_sptf_sampler.c"], "APPLY_SPECIFIC(dnn_sptf_sampler)")
......@@ -2895,6 +2897,7 @@ class GpuDnnTransformerGradI(DnnBase):
_cop_num_inputs = 4
_cop_num_outputs = 2
_f16_ok = True
check_input = False
def __init__(self, dtype=theano.config.floatX):
DnnBase.__init__(self, ["c_code/dnn_sptf_gi.c"], "APPLY_SPECIFIC(dnn_sptf_gi)")
......@@ -2939,6 +2942,7 @@ class GpuDnnTransformerGradT(DnnBase):
_cop_num_inputs = 2
_cop_num_outputs = 1
_f16_ok = True
check_input = False
def __init__(self):
DnnBase.__init__(self, ["c_code/dnn_sptf_gt.c"], "APPLY_SPECIFIC(dnn_sptf_gt)")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论