Fix apply node setup and connection_pattern in GpuDnnTransformerGradI

上级 a9181f25
...@@ -2960,7 +2960,7 @@ class GpuDnnTransformerGradI(DnnBase): ...@@ -2960,7 +2960,7 @@ class GpuDnnTransformerGradI(DnnBase):
""" """
__props__ = ('dtype',) __props__ = ('dtype',)
_cop_num_inputs = 8 _cop_num_inputs = 8
_cop_num_outputs = 1 _cop_num_outputs = 2
_f16_ok = True _f16_ok = True
def __init__(self, dtype=theano.config.floatX): def __init__(self, dtype=theano.config.floatX):
...@@ -2980,6 +2980,8 @@ class GpuDnnTransformerGradI(DnnBase): ...@@ -2980,6 +2980,8 @@ class GpuDnnTransformerGradI(DnnBase):
grid = as_gpuarray_variable(gpu_contiguous(grid), context_name) grid = as_gpuarray_variable(gpu_contiguous(grid), context_name)
grid_dims = as_tensor_variable(grid_dims) grid_dims = as_tensor_variable(grid_dims)
dy = as_gpuarray_variable(dy, context_name) dy = as_gpuarray_variable(dy, context_name)
alpha = as_scalar(alpha)
beta = as_scalar(beta)
dimg = GpuArrayType(dtype=self.dtype, dimg = GpuArrayType(dtype=self.dtype,
broadcastable=img.type.ndim * (False,), broadcastable=img.type.ndim * (False,),
...@@ -2995,7 +2997,7 @@ class GpuDnnTransformerGradI(DnnBase): ...@@ -2995,7 +2997,7 @@ class GpuDnnTransformerGradI(DnnBase):
def connection_pattern(self, node): def connection_pattern(self, node):
# not connected to desc # not connected to desc
return [[1], [1], [1], [1], [1], [0], [1], [1]] return [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [0, 0], [1, 1], [1, 1]]
class GpuDnnTransformerGradT(DnnBase): class GpuDnnTransformerGradT(DnnBase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论