提交 f95dc2d5 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix a stability optimization in nnet.

The special casting case is unneeded, and makes things ugly when infer_shape of subtensor is implemented.
上级 7d9b8e83
...@@ -1103,14 +1103,6 @@ def local_argmax_pushdown(node): ...@@ -1103,14 +1103,6 @@ def local_argmax_pushdown(node):
def _check_rows_is_arange_len_labels(rows, labels): def _check_rows_is_arange_len_labels(rows, labels):
'''Check that 'rows' is the same node as T.arange(labels.shape[0])''' '''Check that 'rows' is the same node as T.arange(labels.shape[0])'''
# this is admittedly a pretty random thing to have here... but it's not wrong (I think)
# and it has the effect of making the advanced_indexing -> crossentropy optimization work
# in the case where the labels are float32s casted to integers. "Why would anyone do that?"
# you ask... it is a handy trick for storing labels on a pre-FERMI GPU device so that
# logistic regression goes faster.
if labels.owner and labels.owner.op == tensor._convert_to_int32:
labels = labels.owner.inputs[0]
if rows.owner and isinstance(rows.owner.op, tensor.ARange): if rows.owner and isinstance(rows.owner.op, tensor.ARange):
start, stop, step = rows.owner.inputs start, stop, step = rows.owner.inputs
if getattr(start, 'data', None) != 0: #constants will have data if getattr(start, 'data', None) != 0: #constants will have data
...@@ -1119,11 +1111,12 @@ def _check_rows_is_arange_len_labels(rows, labels): ...@@ -1119,11 +1111,12 @@ def _check_rows_is_arange_len_labels(rows, labels):
return False return False
if not stop.owner: if not stop.owner:
return False return False
# Not sure if that case happens any more after the introduction
# of ShapeOptimizer # Not sure if that case happens any more after the introduction of
# ShapeOptimizer, but we keep it if ShapeOptimizer is not present
if isinstance(stop.owner.op, tensor.Subtensor): if isinstance(stop.owner.op, tensor.Subtensor):
shape_subtensor = stop.owner shape_subtensor = stop.owner
if shape_subtensor.op.idx_list == [0]: if list(shape_subtensor.op.idx_list) == [0]:
shape_var, = shape_subtensor.inputs shape_var, = shape_subtensor.inputs
if shape_var.owner and shape_var.owner.op == tensor._shape: if shape_var.owner and shape_var.owner.op == tensor._shape:
return shape_var.owner.inputs[0] is labels return shape_var.owner.inputs[0] is labels
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论