提交 2ccd511d authored 作者: Frederic Bastien's avatar Frederic Bastien

Force the abstract BN ops to have common dtype for tensors input.

上级 1e68d76b
......@@ -88,6 +88,14 @@ def upcast(dtype, *dtypes):
return rval
def as_common_dtype(*vars):
"""
For for theano.scalar.Scalar and TensorVariable.
"""
dtype = upcast(*[v.dtype for v in vars])
return (v.astype(dtype) for v in vars)
def get_scalar_type(dtype):
"""
Return a Scalar(dtype) object.
......
......@@ -7,7 +7,7 @@ from theano.gof.opt import copy_stack_trace
from theano.tensor import as_tensor_variable, TensorType
from theano.tensor import basic as T
from theano.tensor.opt import register_specialize_device
from theano.scalar import Composite
from theano.scalar import Composite, as_common_dtype
from theano.scalar import add, sub, true_div, mul
......@@ -427,6 +427,13 @@ class AbstractBatchNormTrain(Op):
(running_mean is not None and running_var is not None))
assert (running_mean is None or running_mean.ndim == x.ndim)
assert (running_var is None or running_var.ndim == x.ndim)
# Upcast to common dtype on the non-scalar
# Keep as is dtype of scalar (epsilon and running_average_factor)
if running_mean:
x, scale, bias, running_mean, running_var = as_common_dtype(
x, scale, bias, running_mean, running_var)
else:
x, scale, bias = as_common_dtype(x, scale, bias)
inputs = [x, scale, bias, epsilon, running_average_factor]
output_types = [x.type(), scale.type(), scale.type()]
if running_mean is not None and running_var is not None:
......@@ -524,6 +531,10 @@ class AbstractBatchNormInference(Op):
estimated_mean = as_tensor_variable(estimated_mean)
estimated_variance = as_tensor_variable(estimated_variance)
epsilon = as_tensor_variable(epsilon)
# Upcast to common dtype on the non-scalar
# Keep as is dtype of scalar (epsilon)
x, scale, bias, estimated_mean, estimated_variance = as_common_dtype(
x, scale, bias, estimated_mean, estimated_variance)
assert x.ndim == scale.ndim == bias.ndim == estimated_mean.ndim == estimated_variance.ndim
return Apply(self, [x, scale, bias, estimated_mean, estimated_variance, epsilon], [x.type()])
......@@ -578,6 +589,10 @@ class AbstractBatchNormTrainGrad(Op):
x_invstd = as_tensor_variable(x_invstd)
epsilon = as_tensor_variable(epsilon)
# Upcast to common dtype on the non-scalar
# Keep as is dtype of scalar (epsilon)
x, dy, scale, x_mean, x_invstd = as_common_dtype(
x, dy, scale, x_mean, x_invstd)
assert x.ndim == dy.ndim == scale.ndim == x_mean.ndim == x_invstd.ndim
return Apply(self, [x, dy, scale, x_mean, x_invstd, epsilon],
[x.type(), scale.type(), scale.type()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论