提交 4a2c73ef authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use __call__ instead of make_node to avoid bypassing test values

上级 df02fe22
...@@ -1597,7 +1597,7 @@ class GpuDnnSoftmax(GpuDnnSoftmaxBase): ...@@ -1597,7 +1597,7 @@ class GpuDnnSoftmax(GpuDnnSoftmaxBase):
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
g_sm, = grads g_sm, = grads
sm = self.make_node(x).outputs[0] sm = self(x)
return [GpuDnnSoftmaxGrad( return [GpuDnnSoftmaxGrad(
self.algo, self.algo,
self.mode self.mode
...@@ -1685,7 +1685,7 @@ class GpuDnnBatchNorm(DnnBase): ...@@ -1685,7 +1685,7 @@ class GpuDnnBatchNorm(DnnBase):
def grad(self, inputs, grads): def grad(self, inputs, grads):
x, scale, bias, epsilon = inputs x, scale, bias, epsilon = inputs
dy = grads[0] dy = grads[0]
_, x_mean, x_invstd = self.make_node(x, scale, bias, epsilon).outputs _, x_mean, x_invstd = self(x, scale, bias, epsilon)
return GpuDnnBatchNormGrad(self.mode)(x, dy, scale, x_mean, return GpuDnnBatchNormGrad(self.mode)(x, dy, scale, x_mean,
x_invstd, epsilon) + [DisconnectedType()()] x_invstd, epsilon) + [DisconnectedType()()]
......
...@@ -2248,7 +2248,7 @@ err%(name)s = cudnnSoftmaxForward( ...@@ -2248,7 +2248,7 @@ err%(name)s = cudnnSoftmaxForward(
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
g_sm, = grads g_sm, = grads
sm = self.make_node(x).outputs[0] sm = self(x)
return [GpuDnnSoftmaxGrad( return [GpuDnnSoftmaxGrad(
self.tensor_format, self.tensor_format,
self.algo, self.algo,
...@@ -2603,7 +2603,7 @@ err%(name)s = cudnnBatchNormalizationForwardTraining( ...@@ -2603,7 +2603,7 @@ err%(name)s = cudnnBatchNormalizationForwardTraining(
def grad(self, inputs, grads): def grad(self, inputs, grads):
x, scale, bias = inputs x, scale, bias = inputs
dy = grads[0] dy = grads[0]
_, x_mean, x_invstd = self.make_node(x, scale, bias).outputs _, x_mean, x_invstd = self(x, scale, bias)
return GpuDnnBatchNormGrad(self.mode, self.epsilon)(x, dy, scale, return GpuDnnBatchNormGrad(self.mode, self.epsilon)(x, dy, scale,
x_mean, x_invstd) x_mean, x_invstd)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论