提交 a552c234 authored 作者: Frederic's avatar Frederic

fix code review code clean up.

上级 58ddeb48
...@@ -54,10 +54,10 @@ class SoftmaxWithBias(gof.Op): ...@@ -54,10 +54,10 @@ class SoftmaxWithBias(gof.Op):
x = tensor.as_tensor_variable(x) x = tensor.as_tensor_variable(x)
b = tensor.as_tensor_variable(b) b = tensor.as_tensor_variable(b)
if x.type.ndim != 2 \ if x.type.ndim != 2 \
or x.type.dtype not in ['float32', 'float64']: or x.type.dtype not in tensor.float_dtypes:
raise ValueError('x must be 2-d tensor of floats') raise ValueError('x must be 2-d tensor of floats')
if b.type.ndim != 1 \ if b.type.ndim != 1 \
or x.type.dtype not in ['float32', 'float64']: or x.type.dtype not in tensor.float_dtypes:
raise ValueError('b must be 1-d tensor of floats') raise ValueError('b must be 1-d tensor of floats')
sm = x.type.make_variable() sm = x.type.make_variable()
...@@ -351,7 +351,7 @@ class Softmax(gof.Op): ...@@ -351,7 +351,7 @@ class Softmax(gof.Op):
def make_node(self, x): def make_node(self, x):
x = tensor.as_tensor_variable(x) x = tensor.as_tensor_variable(x)
if x.type.ndim not in (1, 2) \ if x.type.ndim not in (1, 2) \
or x.type.dtype not in ['float32', 'float64']: or x.type.dtype not in tensor.float_dtypes:
raise ValueError('x must be 1-d or 2-d tensor of floats') raise ValueError('x must be 1-d or 2-d tensor of floats')
if x.ndim == 1: if x.ndim == 1:
x = tensor.shape_padleft(x, n_ones=1) x = tensor.shape_padleft(x, n_ones=1)
...@@ -746,10 +746,10 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -746,10 +746,10 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
b = tensor.as_tensor_variable(b) b = tensor.as_tensor_variable(b)
y_idx = tensor.as_tensor_variable(y_idx) y_idx = tensor.as_tensor_variable(y_idx)
if x.type.ndim != 2 \ if x.type.ndim != 2 \
or x.type.dtype not in ['float32', 'float64']: or x.type.dtype not in tensor.float_dtypes:
raise ValueError('x must be 2-d tensor of floats', x.type) raise ValueError('x must be 2-d tensor of floats', x.type)
if b.type.ndim != 1 \ if b.type.ndim != 1 \
or x.type.dtype not in ['float32', 'float64']: or x.type.dtype not in tensor.float_dtypes:
raise ValueError('b must be 1-d tensor of floats', b.type) raise ValueError('b must be 1-d tensor of floats', b.type)
if y_idx.type.ndim != 1 \ if y_idx.type.ndim != 1 \
or y_idx.type.dtype not in tensor.discrete_dtypes: or y_idx.type.dtype not in tensor.discrete_dtypes:
...@@ -974,11 +974,11 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op): ...@@ -974,11 +974,11 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
sm = tensor.as_tensor_variable(sm) sm = tensor.as_tensor_variable(sm)
y_idx = tensor.as_tensor_variable(y_idx) y_idx = tensor.as_tensor_variable(y_idx)
if (dy.type.ndim != 1 or if (dy.type.ndim != 1 or
dy.type.dtype not in ['float32', 'float64']): dy.type.dtype not in tensor.float_dtypes):
raise ValueError('dy must be 1-d tensor of floats', dy.type) raise ValueError('dy must be 1-d tensor of floats', dy.type)
if (sm.type.ndim != 2 or if (sm.type.ndim != 2 or
sm.type.dtype not in ['float32', 'float64']): sm.type.dtype not in tensor.float_dtypes):
raise ValueError('sm must be 1-d tensor of floats', sm.type) raise ValueError('sm must be 2-d tensor of floats', sm.type)
if (y_idx.type.ndim != 1 or if (y_idx.type.ndim != 1 or
y_idx.type.dtype not in tensor.discrete_dtypes): y_idx.type.dtype not in tensor.discrete_dtypes):
raise ValueError('y_idx must be 1-d tensor of [u]ints', y_idx.type) raise ValueError('y_idx must be 1-d tensor of [u]ints', y_idx.type)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论