Renamed tinit to astensor

上级 a68aa7e2
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import unittest import unittest
import numpy import numpy
from tensor import tinit, Tensor from tensor import astensor, Tensor
import gof import gof
from gof import modes, Env from gof import modes, Env
...@@ -29,9 +29,9 @@ def inputs(): ...@@ -29,9 +29,9 @@ def inputs():
l1 = [[1.0, 2.0], [3.0, 4.0]] l1 = [[1.0, 2.0], [3.0, 4.0]]
l2 = [[3.0, 4.0], [1.0, 2.0]] l2 = [[3.0, 4.0], [1.0, 2.0]]
l3 = numpy.ones((2, 3)) l3 = numpy.ones((2, 3))
x = modes.build(tinit(l1, name = 'x')) x = modes.build(astensor(l1, name = 'x'))
y = modes.build(tinit(l2, name = 'y')) y = modes.build(astensor(l2, name = 'y'))
z = modes.build(tinit(l3, name = 'z')) z = modes.build(astensor(l3, name = 'z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
......
差异被折叠。
...@@ -67,7 +67,7 @@ class Tensor(BaseTensor): ...@@ -67,7 +67,7 @@ class Tensor(BaseTensor):
def __getslice__(self, *args): return subtensor(self, slice(*args)) def __getslice__(self, *args): return subtensor(self, slice(*args))
# alternate Tensor constructor # alternate Tensor constructor
def tinit(data, broadcastable=None, role=None, name=None): def astensor(data, broadcastable=None, role=None, name=None):
"""Return a Tensor containing given data""" """Return a Tensor containing given data"""
data = numpy.asarray(data) data = numpy.asarray(data)
if broadcastable is None: if broadcastable is None:
...@@ -90,7 +90,7 @@ def _scalar_switch(normal_f, scalar_f, scalar_f_reverse = None): ...@@ -90,7 +90,7 @@ def _scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
if isinstance(obj, Tensor): if isinstance(obj, Tensor):
return obj return obj
else: else:
return tinit(obj) return astensor(obj)
x, y = as_tensor(x), as_tensor(y) x, y = as_tensor(x), as_tensor(y)
if 0 not in y.broadcastable: if 0 not in y.broadcastable:
return scalar_f(x, y) return scalar_f(x, y)
...@@ -119,7 +119,7 @@ def _as_tensor(obj): ...@@ -119,7 +119,7 @@ def _as_tensor(obj):
if isinstance(obj, Tensor): if isinstance(obj, Tensor):
return obj return obj
else: else:
return tinit(obj) return astensor(obj)
class _Op(BaseTensorOp): class _Op(BaseTensorOp):
"""A convenient base for the ops in this file""" """A convenient base for the ops in this file"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论