提交 2cb3a154 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Replace theano.tensor alias T with tt in theano.typed_list

上级 219428ba
import numpy as np
from .type import TypedListType
import theano
import theano.tensor as tt
from theano.typed_list.type import TypedListType
from theano.gof import Apply, Constant, Op, Variable
from theano.tensor.type_other import SliceType
from theano import tensor as T
from theano.compile.debugmode import _lessbroken_deepcopy
......@@ -74,11 +75,11 @@ class GetItem(Op):
index = Constant(SliceType(), index)
return Apply(self, [x, index], [x.type()])
else:
index = T.constant(index, ndim=0, dtype="int64")
index = tt.constant(index, ndim=0, dtype="int64")
return Apply(self, [x, index], [x.ttype()])
if isinstance(index.type, SliceType):
return Apply(self, [x, index], [x.type()])
elif isinstance(index, T.TensorVariable) and index.ndim == 0:
elif isinstance(index, tt.TensorVariable) and index.ndim == 0:
assert index.dtype == "int64"
return Apply(self, [x, index], [x.ttype()])
else:
......@@ -325,10 +326,10 @@ class Insert(Op):
assert isinstance(x.type, TypedListType)
assert x.ttype == toInsert.type
if not isinstance(index, Variable):
index = T.constant(index, ndim=0, dtype="int64")
index = tt.constant(index, ndim=0, dtype="int64")
else:
assert index.dtype == "int64"
assert isinstance(index, T.TensorVariable) and index.ndim == 0
assert isinstance(index, tt.TensorVariable) and index.ndim == 0
return Apply(self, [x, index, toInsert], [x.type()])
def perform(self, node, inputs, outputs):
......@@ -536,7 +537,7 @@ class Index(Op):
def make_node(self, x, elem):
assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type
return Apply(self, [x, elem], [T.scalar()])
return Apply(self, [x, elem], [tt.scalar()])
def perform(self, node, inputs, outputs):
"""
......@@ -565,7 +566,7 @@ class Count(Op):
def make_node(self, x, elem):
assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type
return Apply(self, [x, elem], [T.scalar()])
return Apply(self, [x, elem], [tt.scalar()])
def perform(self, node, inputs, outputs):
"""
......@@ -611,7 +612,7 @@ class Length(Op):
def make_node(self, x):
assert isinstance(x.type, TypedListType)
return Apply(self, [x], [T.scalar(dtype="int64")])
return Apply(self, [x], [tt.scalar(dtype="int64")])
def perform(self, node, x, outputs):
(out,) = outputs
......@@ -658,7 +659,7 @@ class MakeList(Op):
a2 = []
for elem in a:
if not isinstance(elem, theano.gof.Variable):
elem = theano.tensor.as_tensor_variable(elem)
elem = tt.as_tensor_variable(elem)
a2.append(elem)
if not all(a2[0].type == elem.type for elem in a2):
raise TypeError("MakeList need all input variable to be of the same type.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论