Commit ec9ece68 authored by James Bergstra

nnet opt - hacked _check_rows... helper function so that the advanced_indexing -> crossentropy optimization still works when storing labels on the GPU.
Parent ab5a218b
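For context, the trick the commit message refers to looks roughly like the following sketch, written in the style of Theano's logistic-regression tutorial (all names and sizes here are hypothetical): labels are kept in a float32 shared variable so they can be stored on a pre-Fermi GPU, then cast back to int32 for the advanced-indexing expression that the optimizer rewrites into a fused cross-entropy.

    import numpy
    import theano
    import theano.tensor as T

    # Hypothetical sizes: 784 features, 10 classes, 128 examples.
    # Labels live in a float32 shared variable so they can sit on a
    # pre-Fermi GPU (whose storage is float32-only); the cast recovers
    # integer labels for indexing.
    y_float = theano.shared(numpy.zeros(128, dtype='float32'), name='y_float')
    y = T.cast(y_float, 'int32')

    x = T.fmatrix('x')
    W = theano.shared(numpy.zeros((784, 10), dtype='float32'), name='W')
    b = theano.shared(numpy.zeros(10, dtype='float32'), name='b')

    p_y = T.nnet.softmax(T.dot(x, W) + b)
    # Advanced indexing picks out each row's probability of its true label;
    # the row index T.arange(y.shape[0]) is the graph pattern that
    # _check_rows_is_arange_len_labels is written to recognize.
    nll = -T.mean(T.log(p_y)[T.arange(y.shape[0]), y])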
@@ -923,13 +923,20 @@ def local_argmax_pushdown(node):
 def _check_rows_is_arange_len_labels(rows, labels):
     '''Check that 'rows' is the same node as T.arange(labels.shape[0])'''
+    # this is admittedly a pretty random thing to have here... but it's not wrong (I think)
+    # and it has the effect of making the advanced_indexing -> crossentropy optimization work
+    # in the case where the labels are float32s casted to integers. "Why would anyone do that?"
+    # you ask... it is a handy trick for storing labels on a pre-FERMI GPU device so that
+    # logistic regression goes faster.
+    if labels.owner and labels.owner.op == tensor._convert_to_int32:
+        labels = labels.owner.inputs[0]
     if rows.owner and isinstance(rows.owner.op, tensor.ARange):
         start, stop, step = rows.owner.inputs
         if getattr(start, 'data', None) != 0: #constants will have data
             return False
         if getattr(step, 'data', None) != 1: # constant step will have data
             return False
         if stop.owner and isinstance(stop.owner.op, tensor.Subtensor):
             shape_subtensor = stop.owner
             if shape_subtensor.op.idx_list == [0]:
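Continuing the sketch above, this is roughly what the helper sees when it inspects the row index (the .owner, .inputs, and .data graph attributes are Theano's; the interpretation of the checks not shown in this hunk is an assumption):

    rows = T.arange(y.shape[0])
    print(type(rows.owner.op))            # tensor.ARange
    start, stop, step = rows.owner.inputs
    print(start.data, step.data)          # 0 1 -- constants carry .data
    # stop comes from a Subtensor{0} over a shape node; once Theano's shape
    # propagation sees through the elemwise cast, that shape can refer to
    # y_float rather than y, which is presumably why the helper now strips
    # the int32 cast from the labels before comparing nodes.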
@@ -976,6 +983,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
     except:
         pass
     if sm is not None and sm.owner and sm.owner.op in (softmax, softmax_with_bias):
         sm_w_bias = local_softmax_with_bias.transform(sm.owner)
         if sm_w_bias:
...
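A hypothetical way to confirm the optimization now fires on cast labels is to compile the loss from the sketch above and inspect the optimized graph: if the advanced_indexing -> crossentropy rewrite applied, a fused cross-entropy op should appear in place of the separate softmax and advanced-indexing nodes.

    # Compile nll from the sketch above and dump the optimized graph.
    f = theano.function([x], nll)
    theano.printing.debugprint(f)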