Fix apply node setup and connection_pattern in GpuDnnTransformerGradI

上级 a9181f25
......@@ -2960,7 +2960,7 @@ class GpuDnnTransformerGradI(DnnBase):
"""
__props__ = ('dtype',)
_cop_num_inputs = 8
_cop_num_outputs = 1
_cop_num_outputs = 2
_f16_ok = True
def __init__(self, dtype=theano.config.floatX):
......@@ -2980,6 +2980,8 @@ class GpuDnnTransformerGradI(DnnBase):
grid = as_gpuarray_variable(gpu_contiguous(grid), context_name)
grid_dims = as_tensor_variable(grid_dims)
dy = as_gpuarray_variable(dy, context_name)
alpha = as_scalar(alpha)
beta = as_scalar(beta)
dimg = GpuArrayType(dtype=self.dtype,
broadcastable=img.type.ndim * (False,),
......@@ -2995,7 +2997,7 @@ class GpuDnnTransformerGradI(DnnBase):
def connection_pattern(self, node):
# not connected to desc
return [[1], [1], [1], [1], [1], [0], [1], [1]]
return [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [0, 0], [1, 1], [1, 1]]
class GpuDnnTransformerGradT(DnnBase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论