提交 14e3a1e8 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix alpha_merge and output_merge.

上级 ecba472d
...@@ -385,17 +385,17 @@ _zero = constant(numpy.asarray(0.0, dtype='float64')) ...@@ -385,17 +385,17 @@ _zero = constant(numpy.asarray(0.0, dtype='float64'))
_one = constant(numpy.asarray(1.0, dtype='float64')) _one = constant(numpy.asarray(1.0, dtype='float64'))
def ensure_double(val, default, name): def ensure_dt(val, default, name, dtype):
if val is None: if val is None:
return default.clone() val = default.clone()
if not isinstance(val, Variable): if not isinstance(val, Variable):
val = constant(val).astype('float64') val = constant(val)
if hasattr(val, 'ndim') and val.ndim == 0: if hasattr(val, 'ndim') and val.ndim == 0:
val = as_scalar(val) val = as_scalar(val)
if not isinstance(val.type, theano.scalar.Scalar): if not isinstance(val.type, theano.scalar.Scalar):
raise TypeError("%s: expected a scalar value" % (name,)) raise TypeError("%s: expected a scalar value" % (name,))
if not val.type.dtype == 'float64': if not val.type.dtype == dtype:
raise TypeError("%s: type is not float64" % (name,)) val = val.astype(dtype)
return val return val
...@@ -456,8 +456,8 @@ class GpuDnnConv(DnnBase, COp): ...@@ -456,8 +456,8 @@ class GpuDnnConv(DnnBase, COp):
or desc.type.ctype != 'cudnnConvolutionDescriptor_t': or desc.type.ctype != 'cudnnConvolutionDescriptor_t':
raise TypeError('desc must be cudnnConvolutionDescriptor_t') raise TypeError('desc must be cudnnConvolutionDescriptor_t')
alpha = ensure_double(alpha, _one, 'alpha') alpha = ensure_dt(alpha, _one, 'alpha', img.dtype)
beta = ensure_double(beta, _zero, 'beta') beta = ensure_dt(beta, _zero, 'beta', img.dtype)
return Apply(self, [img, kern, output, desc, alpha, beta], return Apply(self, [img, kern, output, desc, alpha, beta],
[output.type()]) [output.type()])
...@@ -577,8 +577,8 @@ class GpuDnnConvGradW(DnnBase, COp): ...@@ -577,8 +577,8 @@ class GpuDnnConvGradW(DnnBase, COp):
or desc.type.ctype != 'cudnnConvolutionDescriptor_t': or desc.type.ctype != 'cudnnConvolutionDescriptor_t':
raise TypeError('desc must be cudnnConvolutionDescriptor_t') raise TypeError('desc must be cudnnConvolutionDescriptor_t')
alpha = ensure_double(alpha, _one, 'alpha') alpha = ensure_dt(alpha, _one, 'alpha', img.dtype)
beta = ensure_double(beta, _zero, 'beta') beta = ensure_dt(beta, _zero, 'beta', img.dtype)
return Apply(self, [img, topgrad, output, desc, alpha, beta], return Apply(self, [img, topgrad, output, desc, alpha, beta],
[output.type()]) [output.type()])
...@@ -644,8 +644,8 @@ class GpuDnnConvGradI(DnnBase): ...@@ -644,8 +644,8 @@ class GpuDnnConvGradI(DnnBase):
or desc.type.ctype != 'cudnnConvolutionDescriptor_t': or desc.type.ctype != 'cudnnConvolutionDescriptor_t':
raise TypeError('desc must be cudnnConvolutionDescriptor_t') raise TypeError('desc must be cudnnConvolutionDescriptor_t')
alpha = ensure_double(alpha, _one, 'alpha') alpha = ensure_dt(alpha, _one, 'alpha', kern.dtype)
beta = ensure_double(beta, _zero, 'beta') beta = ensure_dt(beta, _zero, 'beta', kern.dtype)
return Apply(self, [kern, topgrad, output, desc, alpha, beta], return Apply(self, [kern, topgrad, output, desc, alpha, beta],
[output.type()]) [output.type()])
......
...@@ -7,10 +7,10 @@ from theano.gof import local_optimizer ...@@ -7,10 +7,10 @@ from theano.gof import local_optimizer
from theano.tensor import (DimShuffle, get_scalar_constant_value, from theano.tensor import (DimShuffle, get_scalar_constant_value,
NotScalarConstantError) NotScalarConstantError)
from .basic_ops import GpuFromHost, HostFromGpu, host_from_gpu from .basic_ops import GpuFromHost, HostFromGpu
from .elemwise import GpuDimShuffle, GpuElemwise from .elemwise import GpuDimShuffle, GpuElemwise
_one = scal.constant(numpy.asarray(1.0, dtype='float32')) _one = scal.constant(numpy.asarray(1.0, dtype='float64'))
def grab_cpu_scalar(v, nd): def grab_cpu_scalar(v, nd):
...@@ -18,10 +18,10 @@ def grab_cpu_scalar(v, nd): ...@@ -18,10 +18,10 @@ def grab_cpu_scalar(v, nd):
n = v.owner n = v.owner
if (isinstance(n.op, GpuDimShuffle) and if (isinstance(n.op, GpuDimShuffle) and
n.op.new_order == ('x',) * nd): n.op.new_order == ('x',) * nd):
return host_from_gpu(n.inputs[0]) return grab_cpu_scalar(n.inputs[0])
elif (isinstance(n.op, DimShuffle) and elif (isinstance(n.op, DimShuffle) and
n.op.new_order == ('x',) * nd): n.op.new_order == ('x',) * nd):
return n.inputs[0] return grab_cpu_scalar(n.inputs[0])
elif isinstance(n.op, GpuFromHost): elif isinstance(n.op, GpuFromHost):
return grab_cpu_scalar(n.inputs[0], nd=nd) return grab_cpu_scalar(n.inputs[0], nd=nd)
else: else:
...@@ -37,7 +37,7 @@ def find_node(v, cls, ignore_clients=False): ...@@ -37,7 +37,7 @@ def find_node(v, cls, ignore_clients=False):
# that has the op class specified. If ignore_clients is False (the # that has the op class specified. If ignore_clients is False (the
# default) it will only dig through nodes that have a single # default) it will only dig through nodes that have a single
# client. # client.
if v.owner is not None and (ignore_clients or v.clients == 1): if v.owner is not None and (ignore_clients or len(v.clients) == 1):
if isinstance(v.owner.op, cls): if isinstance(v.owner.op, cls):
return v.owner return v.owner
elif (isinstance(v.owner.op, GpuFromHost) and elif (isinstance(v.owner.op, GpuFromHost) and
......
...@@ -593,7 +593,7 @@ def get_scalar_constant_value(orig_v, elemwise=True, ...@@ -593,7 +593,7 @@ def get_scalar_constant_value(orig_v, elemwise=True,
# mess with the stabilization optimization and be too slow. # mess with the stabilization optimization and be too slow.
# We put all the scalar Ops used by get_canonical_form_slice() # We put all the scalar Ops used by get_canonical_form_slice()
# to allow it to determine the broadcast pattern correctly. # to allow it to determine the broadcast pattern correctly.
elif isinstance(v.owner.op, ScalarFromTensor): elif isinstance(v.owner.op, (ScalarFromTensor, TensorFromScalar)):
return get_scalar_constant_value(v.owner.inputs[0]) return get_scalar_constant_value(v.owner.inputs[0])
elif isinstance(v.owner.op, scal.ScalarOp): elif isinstance(v.owner.op, scal.ScalarOp):
if isinstance(v.owner.op, scal.Second): if isinstance(v.owner.op, scal.Second):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论