提交 000fa06c authored 作者: Frederic's avatar Frederic

Code cleanup and make SoftmaxGrad add DimShuffle as Sofmax

上级 7a9862f7
...@@ -1729,6 +1729,8 @@ if True: ...@@ -1729,6 +1729,8 @@ if True:
for n in node.inputs: for n in node.inputs:
if isinstance(n.owner.op, HostFromGpu): if isinstance(n.owner.op, HostFromGpu):
n = n.owner.inputs[0] n = n.owner.inputs[0]
if n.ndim != 2:
return
ins.append(n.dimshuffle(0, 1, 'x', 'x')) ins.append(n.dimshuffle(0, 1, 'x', 'x'))
out = GpuDnnSoftmaxGrad( out = GpuDnnSoftmaxGrad(
......
...@@ -279,21 +279,19 @@ class SoftmaxGrad(gof.Op): ...@@ -279,21 +279,19 @@ class SoftmaxGrad(gof.Op):
nin = 2 nin = 2
nout = 1 nout = 1
def __init__(self, **kwargs): __props__ = ()
gof.Op.__init__(self, **kwargs)
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return tensor.hashtype(self)
def __str__(self):
return self.__class__.__name__
def make_node(self, dy, sm, **kwargs): def make_node(self, dy, sm):
dy = tensor.as_tensor_variable(dy) dy = tensor.as_tensor_variable(dy)
sm = tensor.as_tensor_variable(sm) sm = tensor.as_tensor_variable(sm)
if dy.type.ndim not in (1, 2) \
or dy.type.dtype not in tensor.float_dtypes:
raise ValueError('dy must be 1-d or 2-d tensor of floats. Got ',
dy.type)
if dy.ndim == 1:
dy = tensor.shape_padleft(dy, n_ones=1)
if sm.ndim == 1:
sm = tensor.shape_padleft(sm, n_ones=1)
return Apply(self, [dy, sm], [sm.type.make_variable()]) return Apply(self, [dy, sm], [sm.type.make_variable()])
def perform(self, node, input_storage, output_storage): def perform(self, node, input_storage, output_storage):
...@@ -394,24 +392,14 @@ class Softmax(gof.Op): ...@@ -394,24 +392,14 @@ class Softmax(gof.Op):
nin = 1 nin = 1
nout = 1 nout = 1
__props__ = ()
def __init__(self, **kwargs):
gof.Op.__init__(self, **kwargs)
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def __str__(self):
return self.__class__.__name__
def make_node(self, x): def make_node(self, x):
x = tensor.as_tensor_variable(x) x = tensor.as_tensor_variable(x)
if x.type.ndim not in (1, 2) \ if x.type.ndim not in (1, 2) \
or x.type.dtype not in tensor.float_dtypes: or x.type.dtype not in tensor.float_dtypes:
raise ValueError('x must be 1-d or 2-d tensor of floats. Got ', x.type) raise ValueError('x must be 1-d or 2-d tensor of floats. Got ',
x.type)
if x.ndim == 1: if x.ndim == 1:
x = tensor.shape_padleft(x, n_ones=1) x = tensor.shape_padleft(x, n_ones=1)
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论