提交 13f31289 authored 作者: Sina Honari's avatar Sina Honari 提交者: Pascal Lamblin

initial commit for gradient for MRG_RandomStreams distribution variables should…

initial commit for gradient for MRG_RandomStreams distribution variables should be null - issue #4078
上级 a5c029dc
...@@ -1994,6 +1994,38 @@ def zero_grad(x): ...@@ -1994,6 +1994,38 @@ def zero_grad(x):
return zero_grad_(x) return zero_grad_(x)
class UndefinedGrad(ViewOp):
def grad(self, args, g_outs):
return [grad_undefined(self, i, inp[i]) for i in xrange(args)]
def R_op(self, inputs, eval_points):
return [None]
def connection_pattern(self, node):
return [[False]]
undefined_grad_ = UndefinedGrad()
def undefined_grad(x):
"""
Consider the gradient of this variable undefined and
generate an error message if its gradient is taken.
The expression itself is unaffected, but when its gradient is
computed, or the gradient of another expression that this
expression is a subexpression of, an error message will be generated
specifying such gradient is not defined.
:param x: A Theano expression whose gradient should be undefined.
:return: The expression is returned unmodified, but its gradient
is now undefined.
"""
return undefined_grad_(x)
class DisconnectedGrad(ViewOp): class DisconnectedGrad(ViewOp):
def grad(self, args, g_outs): def grad(self, args, g_outs):
return [disconnected_type() for g_out in g_outs] return [disconnected_type() for g_out in g_outs]
......
...@@ -15,6 +15,7 @@ from six.moves import xrange ...@@ -15,6 +15,7 @@ from six.moves import xrange
import theano import theano
from theano import Op, Apply, shared, config, Variable from theano import Op, Apply, shared, config, Variable
from theano import gradient, function from theano import gradient, function
from theano.gradient import undefined_grad
from theano import tensor from theano import tensor
from theano.tensor import (TensorType, as_tensor_variable, get_vector_length, from theano.tensor import (TensorType, as_tensor_variable, get_vector_length,
cast, opt, scal) cast, opt, scal)
...@@ -773,7 +774,9 @@ class MRG_RandomStreams(object): ...@@ -773,7 +774,9 @@ class MRG_RandomStreams(object):
""" """
low = as_tensor_variable(low) low = as_tensor_variable(low)
low = undefined_grad(low)
high = as_tensor_variable(high) high = as_tensor_variable(high)
high = undefined_grad(high)
if dtype is None: if dtype is None:
dtype = scal.upcast(config.floatX, low.dtype, high.dtype) dtype = scal.upcast(config.floatX, low.dtype, high.dtype)
...@@ -821,6 +824,7 @@ class MRG_RandomStreams(object): ...@@ -821,6 +824,7 @@ class MRG_RandomStreams(object):
nstreams=None): nstreams=None):
# TODO : need description for method, parameter and return # TODO : need description for method, parameter and return
if n == 1: if n == 1:
p = undefined_grad(p)
x = self.uniform(size=size, nstreams=nstreams) x = self.uniform(size=size, nstreams=nstreams)
return cast(x < p, dtype) return cast(x < p, dtype)
else: else:
...@@ -852,6 +856,7 @@ class MRG_RandomStreams(object): ...@@ -852,6 +856,7 @@ class MRG_RandomStreams(object):
if pvals is None: if pvals is None:
raise TypeError("You have to specify pvals") raise TypeError("You have to specify pvals")
pvals = as_tensor_variable(pvals) pvals = as_tensor_variable(pvals)
pvals = undefined_grad(pvals)
if size is not None: if size is not None:
if any([isinstance(i, integer_types) and i <= 0 for i in size]): if any([isinstance(i, integer_types) and i <= 0 for i in size]):
raise ValueError( raise ValueError(
...@@ -978,7 +983,9 @@ class MRG_RandomStreams(object): ...@@ -978,7 +983,9 @@ class MRG_RandomStreams(object):
# second half our U2's. See Wikipedia page: # second half our U2's. See Wikipedia page:
# http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform # http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
avg = as_tensor_variable(avg) avg = as_tensor_variable(avg)
avg = undefined_grad(avg)
std = as_tensor_variable(std) std = as_tensor_variable(std)
std = undefined_grad(std)
if dtype is None: if dtype is None:
dtype = scal.upcast(config.floatX, avg.dtype, std.dtype) dtype = scal.upcast(config.floatX, avg.dtype, std.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论