提交 f4f8b257 authored 作者: Frederic Bastien's avatar Frederic Bastien

small fix following review

上级 a51c01ea
......@@ -268,7 +268,7 @@ class FunctionGraph(utils.object2):
r
Variable.
new_client
(node, i) pairs such that node.inputs[i] is r.
(node, i) pair such that node.inputs[i] is r.
"""
# Ne need to do the assert as it is always True. The logic
......@@ -293,8 +293,8 @@ class FunctionGraph(utils.object2):
----------
r : Variable
The clients of r will be removed.
client_to_remove : (op, i) pairs
(op, i) pairs such that node.inputs[i] is not r anymore.
client_to_remove : (op, i) pair
(op, i) pair such that node.inputs[i] is not r anymore.
prune : bool
If prune is True, it remove r from this fgraph if it don't
have clients left.
......
......@@ -299,7 +299,6 @@ class GraphToGPU(Optimizer):
# Iterating through inputs of graph
target = infer_context_name(*fgraph.inputs)
for i in fgraph.inputs:
# Do not move *int* scalar to the GPU.
if isinstance(i.type, tensor.TensorType) and move_to_gpu(i):
mapping[i] = i.transfer(getattr(i.tag, 'target', target))
else:
......
......@@ -581,25 +581,33 @@ class GpuArraySharedVariable(_operators, SharedVariable):
GpuArrayType.SharedVariable = GpuArraySharedVariable
notset = object()
def gpuarray_shared_constructor(value, name=None, strict=False,
allow_downcast=None, borrow=False,
broadcastable=None, target=None):
broadcastable=None, target=notset):
"""
SharedVariable constructor for GpuArrayType.
See :func:`theano.shared`.
:target: default None
The device target. As None is a valid value and we need to
differentiate from the parameter notset and None, we use a
notset object.
"""
if target is None and not move_to_gpu(value):
raise TypeError('We do not move that data by deault to the GPU')
if target == 'gpu' or target == 'cpu':
raise TypeError('not for me')
if not isinstance(value, (numpy.ndarray, pygpu.gpuarray.GpuArray)):
raise TypeError('ndarray or GpuArray required')
if target is notset:
target = None
if not move_to_gpu(value):
raise TypeError('We do not move that data by default to the GPU')
try:
get_context(target)
except ContextNotDefined:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论