提交 671275d9 authored 作者: Benjamin Scellier's avatar Benjamin Scellier

file theano/gpuarray/opt.py

上级 25652690
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import copy import copy
import numpy import numpy as np
import logging import logging
import pdb import pdb
import time import time
...@@ -622,7 +622,7 @@ def local_gpualloc_memset_0(node): ...@@ -622,7 +622,7 @@ def local_gpualloc_memset_0(node):
inp = node.inputs[0] inp = node.inputs[0]
if (isinstance(inp, GpuArrayConstant) and if (isinstance(inp, GpuArrayConstant) and
inp.data.size == 1 and inp.data.size == 1 and
(numpy.asarray(inp.data) == 0).all()): (np.asarray(inp.data) == 0).all()):
new_op = gpu_alloc(node.op.context_name, memset_0=True) new_op = gpu_alloc(node.op.context_name, memset_0=True)
return [new_op(*node.inputs)] return [new_op(*node.inputs)]
...@@ -632,7 +632,7 @@ def local_gpualloc_memset_0(node): ...@@ -632,7 +632,7 @@ def local_gpualloc_memset_0(node):
def local_gpua_alloc_empty_to_zeros(node): def local_gpua_alloc_empty_to_zeros(node):
if isinstance(node.op, GpuAllocEmpty): if isinstance(node.op, GpuAllocEmpty):
context_name = infer_context_name(*node.inputs) context_name = infer_context_name(*node.inputs)
z = numpy.asarray(0, dtype=node.outputs[0].dtype) z = np.asarray(0, dtype=node.outputs[0].dtype)
return [gpu_alloc(context_name)(as_gpuarray_variable(z, context_name), return [gpu_alloc(context_name)(as_gpuarray_variable(z, context_name),
*node.inputs)] *node.inputs)]
optdb.register('local_gpua_alloc_empty_to_zeros', optdb.register('local_gpua_alloc_empty_to_zeros',
...@@ -830,7 +830,7 @@ def local_gpua_shape_graph(op, context_name, inputs, outputs): ...@@ -830,7 +830,7 @@ def local_gpua_shape_graph(op, context_name, inputs, outputs):
def gpu_print_wrapper(op, cnda): def gpu_print_wrapper(op, cnda):
op.old_op.global_fn(op.old_op, numpy.asarray(cnda)) op.old_op.global_fn(op.old_op, np.asarray(cnda))
@register_opt('fast_compile') @register_opt('fast_compile')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论