提交 2cb3a154 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Replace theano.tensor alias T with tt in theano.typed_list

上级 219428ba
import numpy as np import numpy as np
from .type import TypedListType
import theano import theano
import theano.tensor as tt
from theano.typed_list.type import TypedListType
from theano.gof import Apply, Constant, Op, Variable from theano.gof import Apply, Constant, Op, Variable
from theano.tensor.type_other import SliceType from theano.tensor.type_other import SliceType
from theano import tensor as T
from theano.compile.debugmode import _lessbroken_deepcopy from theano.compile.debugmode import _lessbroken_deepcopy
...@@ -74,11 +75,11 @@ class GetItem(Op): ...@@ -74,11 +75,11 @@ class GetItem(Op):
index = Constant(SliceType(), index) index = Constant(SliceType(), index)
return Apply(self, [x, index], [x.type()]) return Apply(self, [x, index], [x.type()])
else: else:
index = T.constant(index, ndim=0, dtype="int64") index = tt.constant(index, ndim=0, dtype="int64")
return Apply(self, [x, index], [x.ttype()]) return Apply(self, [x, index], [x.ttype()])
if isinstance(index.type, SliceType): if isinstance(index.type, SliceType):
return Apply(self, [x, index], [x.type()]) return Apply(self, [x, index], [x.type()])
elif isinstance(index, T.TensorVariable) and index.ndim == 0: elif isinstance(index, tt.TensorVariable) and index.ndim == 0:
assert index.dtype == "int64" assert index.dtype == "int64"
return Apply(self, [x, index], [x.ttype()]) return Apply(self, [x, index], [x.ttype()])
else: else:
...@@ -325,10 +326,10 @@ class Insert(Op): ...@@ -325,10 +326,10 @@ class Insert(Op):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
assert x.ttype == toInsert.type assert x.ttype == toInsert.type
if not isinstance(index, Variable): if not isinstance(index, Variable):
index = T.constant(index, ndim=0, dtype="int64") index = tt.constant(index, ndim=0, dtype="int64")
else: else:
assert index.dtype == "int64" assert index.dtype == "int64"
assert isinstance(index, T.TensorVariable) and index.ndim == 0 assert isinstance(index, tt.TensorVariable) and index.ndim == 0
return Apply(self, [x, index, toInsert], [x.type()]) return Apply(self, [x, index, toInsert], [x.type()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -536,7 +537,7 @@ class Index(Op): ...@@ -536,7 +537,7 @@ class Index(Op):
def make_node(self, x, elem): def make_node(self, x, elem):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type assert x.ttype == elem.type
return Apply(self, [x, elem], [T.scalar()]) return Apply(self, [x, elem], [tt.scalar()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
""" """
...@@ -565,7 +566,7 @@ class Count(Op): ...@@ -565,7 +566,7 @@ class Count(Op):
def make_node(self, x, elem): def make_node(self, x, elem):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
assert x.ttype == elem.type assert x.ttype == elem.type
return Apply(self, [x, elem], [T.scalar()]) return Apply(self, [x, elem], [tt.scalar()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
""" """
...@@ -611,7 +612,7 @@ class Length(Op): ...@@ -611,7 +612,7 @@ class Length(Op):
def make_node(self, x): def make_node(self, x):
assert isinstance(x.type, TypedListType) assert isinstance(x.type, TypedListType)
return Apply(self, [x], [T.scalar(dtype="int64")]) return Apply(self, [x], [tt.scalar(dtype="int64")])
def perform(self, node, x, outputs): def perform(self, node, x, outputs):
(out,) = outputs (out,) = outputs
...@@ -658,7 +659,7 @@ class MakeList(Op): ...@@ -658,7 +659,7 @@ class MakeList(Op):
a2 = [] a2 = []
for elem in a: for elem in a:
if not isinstance(elem, theano.gof.Variable): if not isinstance(elem, theano.gof.Variable):
elem = theano.tensor.as_tensor_variable(elem) elem = tt.as_tensor_variable(elem)
a2.append(elem) a2.append(elem)
if not all(a2[0].type == elem.type for elem in a2): if not all(a2[0].type == elem.type for elem in a2):
raise TypeError("MakeList need all input variable to be of the same type.") raise TypeError("MakeList need all input variable to be of the same type.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论