提交 a552c234 authored 作者: Frederic's avatar Frederic

fix code review code clean up.

上级 58ddeb48
......@@ -54,10 +54,10 @@ class SoftmaxWithBias(gof.Op):
x = tensor.as_tensor_variable(x)
b = tensor.as_tensor_variable(b)
if x.type.ndim != 2 \
or x.type.dtype not in ['float32', 'float64']:
or x.type.dtype not in tensor.float_dtypes:
raise ValueError('x must be 2-d tensor of floats')
if b.type.ndim != 1 \
or x.type.dtype not in ['float32', 'float64']:
or x.type.dtype not in tensor.float_dtypes:
raise ValueError('b must be 1-d tensor of floats')
sm = x.type.make_variable()
......@@ -351,7 +351,7 @@ class Softmax(gof.Op):
def make_node(self, x):
x = tensor.as_tensor_variable(x)
if x.type.ndim not in (1, 2) \
or x.type.dtype not in ['float32', 'float64']:
or x.type.dtype not in tensor.float_dtypes:
raise ValueError('x must be 1-d or 2-d tensor of floats')
if x.ndim == 1:
x = tensor.shape_padleft(x, n_ones=1)
......@@ -746,10 +746,10 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
b = tensor.as_tensor_variable(b)
y_idx = tensor.as_tensor_variable(y_idx)
if x.type.ndim != 2 \
or x.type.dtype not in ['float32', 'float64']:
or x.type.dtype not in tensor.float_dtypes:
raise ValueError('x must be 2-d tensor of floats', x.type)
if b.type.ndim != 1 \
or x.type.dtype not in ['float32', 'float64']:
or x.type.dtype not in tensor.float_dtypes:
raise ValueError('b must be 1-d tensor of floats', b.type)
if y_idx.type.ndim != 1 \
or y_idx.type.dtype not in tensor.discrete_dtypes:
......@@ -974,11 +974,11 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
sm = tensor.as_tensor_variable(sm)
y_idx = tensor.as_tensor_variable(y_idx)
if (dy.type.ndim != 1 or
dy.type.dtype not in ['float32', 'float64']):
dy.type.dtype not in tensor.float_dtypes):
raise ValueError('dy must be 1-d tensor of floats', dy.type)
if (sm.type.ndim != 2 or
sm.type.dtype not in ['float32', 'float64']):
raise ValueError('sm must be 1-d tensor of floats', sm.type)
sm.type.dtype not in tensor.float_dtypes):
raise ValueError('sm must be 2-d tensor of floats', sm.type)
if (y_idx.type.ndim != 1 or
y_idx.type.dtype not in tensor.discrete_dtypes):
raise ValueError('y_idx must be 1-d tensor of [u]ints', y_idx.type)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论