提交 c90ef03d authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Replace theano.tensor alias T with tt in theano.sandbox

上级 99667825
import copy
import numpy as np
import warnings
import theano
import theano.tensor as tt
from theano import Op, Apply
import theano.tensor as T
from theano.scalar import as_scalar
import copy
class MultinomialFromUniform(Op):
# TODO : need description for parameter 'odtype'
"""
Converts samples from a uniform into sample from a multinomial.
TODO : need description for parameter 'odtype'
"""
__props__ = ("odtype",)
......@@ -31,8 +33,8 @@ class MultinomialFromUniform(Op):
self.odtype = "auto"
def make_node(self, pvals, unis, n=1):
pvals = T.as_tensor_variable(pvals)
unis = T.as_tensor_variable(unis)
pvals = tt.as_tensor_variable(pvals)
unis = tt.as_tensor_variable(unis)
if pvals.ndim != 2:
raise NotImplementedError("pvals ndim should be 2", pvals.ndim)
if unis.ndim != 1:
......@@ -41,16 +43,16 @@ class MultinomialFromUniform(Op):
odtype = pvals.dtype
else:
odtype = self.odtype
out = T.tensor(dtype=odtype, broadcastable=pvals.type.broadcastable)
out = tt.tensor(dtype=odtype, broadcastable=pvals.type.broadcastable)
return Apply(self, [pvals, unis, as_scalar(n)], [out])
def grad(self, ins, outgrads):
pvals, unis, n = ins
(gz,) = outgrads
return [
T.zeros_like(x, dtype=theano.config.floatX)
if x.dtype in T.discrete_dtypes
else T.zeros_like(x)
tt.zeros_like(x, dtype=theano.config.floatX)
if x.dtype in tt.discrete_dtypes
else tt.zeros_like(x)
for x in ins
]
......@@ -237,8 +239,8 @@ class ChoiceFromUniform(MultinomialFromUniform):
self.replace = False
def make_node(self, pvals, unis, n=1):
pvals = T.as_tensor_variable(pvals)
unis = T.as_tensor_variable(unis)
pvals = tt.as_tensor_variable(pvals)
unis = tt.as_tensor_variable(unis)
if pvals.ndim != 2:
raise NotImplementedError("pvals ndim should be 2", pvals.ndim)
if unis.ndim != 1:
......@@ -247,7 +249,7 @@ class ChoiceFromUniform(MultinomialFromUniform):
odtype = "int64"
else:
odtype = self.odtype
out = T.tensor(dtype=odtype, broadcastable=pvals.type.broadcastable)
out = tt.tensor(dtype=odtype, broadcastable=pvals.type.broadcastable)
return Apply(self, [pvals, unis, as_scalar(n)], [out])
def c_code_cache_version(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论