提交 f4d876e9 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add more and keep more values_eq_approx

上级 f4e42e17
...@@ -183,7 +183,11 @@ gpu_optimizer.register('local_remove_all_assert', ...@@ -183,7 +183,11 @@ gpu_optimizer.register('local_remove_all_assert',
# in order to avoid introducin new CPU Ops, or useless ones. # in order to avoid introducin new CPU Ops, or useless ones.
def safe_to_gpu(x, ctx_name): def safe_to_gpu(x, ctx_name):
if isinstance(x.type, tensor.TensorType): if isinstance(x.type, tensor.TensorType):
return GpuFromHost(ctx_name)(x) ret = GpuFromHost(ctx_name)(x)
values_eq_approx = getattr(x.tag, 'values_eq_approx', None)
if values_eq_approx:
ret.tag.values_eq_approx = values_eq_approx
return ret
else: else:
return x return x
......
...@@ -2,6 +2,9 @@ from __future__ import absolute_import, print_function, division ...@@ -2,6 +2,9 @@ from __future__ import absolute_import, print_function, division
import os import os
from string import Template from string import Template
import numpy as np
import theano
from theano import Apply from theano import Apply
from theano.tensor import as_tensor_variable from theano.tensor import as_tensor_variable
from theano.tensor.sort import TopKOp from theano.tensor.sort import TopKOp
...@@ -333,6 +336,21 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -333,6 +336,21 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
return node.inputs[0].type.context return node.inputs[0].type.context
class ValuesEqApproxNoOrder():
"""
We ignore the order of elements on a given axis during the comparison.
"""
def __init__(self, axis):
self.axis = axis
def __call__(self, val1, val2):
v1 = np.sort(val1, axis=self.axis)
v2 = np.sort(val2, axis=self.axis)
ret = theano.tensor.type.values_eq_approx(v1, v2)
return ret
@register_opt('fast_compile') @register_opt('fast_compile')
@op_lifter([TopKOp], cuda_only=True) @op_lifter([TopKOp], cuda_only=True)
@register_opt2([TopKOp], 'fast_compile') @register_opt2([TopKOp], 'fast_compile')
...@@ -350,5 +368,10 @@ def local_gpua_topkop(op, ctx_name, inputs, outputs): ...@@ -350,5 +368,10 @@ def local_gpua_topkop(op, ctx_name, inputs, outputs):
idx_dtype=op.idx_dtype, idx_dtype=op.idx_dtype,
return_values=rv, return_values=rv,
return_indices=ri) return_indices=ri)
rets = gpu_op(x, k) rets = gpu_op(x, k, return_list=True)
c = ValuesEqApproxNoOrder(axis)
if rv or ri:
rets[0].tag.values_eq_approx = c
if rv and ri:
rets[1].tag.values_eq_approx = c
return rets return rets
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论