提交 0f344ee4 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

created BaseTensorOp in base_tensor

上级 ca315a08
......@@ -101,15 +101,15 @@ class T_abs(unittest.TestCase):
class T_fill(unittest.TestCase):
def test0(self):
t = fill(numpy.asarray([1,2,3]), 9.0)
t = fill(numpy.asarray([1,2,3]), 9)
self.failUnless(t.owner.__class__ == Fill)
o = t.owner
self.failUnless(o.inputs[0].broadcastable == (0,))
self.failUnless(o.inputs[0].dtype[0:3] == 'int')
# self.failUnless(o.inputs[0].dtype[0:3] == 'int')
self.failUnless(o.inputs[1].broadcastable == ())
self.failUnless(o.inputs[1].dtype[0:3] == 'flo')
# self.failUnless(o.inputs[1].dtype[0:3] == 'flo')
self.failUnless(o.outputs[0].broadcastable == (0,))
self.failUnless(o.outputs[0].dtype[0:3] == 'flo')
# self.failUnless(o.outputs[0].dtype[0:3] == 'flo')
class T_sum(unittest.TestCase):
def test_impl(self):
......@@ -152,7 +152,7 @@ class T_mul(unittest.TestCase):
def test_operator(self):
a = tinit([1,1])
aa = tinit([1,1])
b = tinit(4.0)
b = tinit(4)
self.failUnless(isinstance((a*b).owner, Scale))
self.failUnless(isinstance((b*a).owner, Scale))
self.failUnless(isinstance((a*aa).owner, MulElemwise))
......
"""A simple class to store ndarray data """
from gof import ResultBase
from gof import ResultBase, Op, utils
import numpy
from copy import copy
......@@ -194,3 +194,68 @@ class BaseTensor(ResultBase):
class BaseTensorOp(Op):
"""
A basic Op subclass that can be used to make Ops that operate on Tensors.
It is not mandatory to inherit from this class, but it is practical.
BasicTensorOp is parametrized as follows:
* nin: number of inputs
* nout: number of outputs
* out_tensor_class: BaseTensor subclass used to instantiate the outputs
* input_wrapper: returns a Tensor from its argument
* propagate_dtype: returns a list of dtypes corresponding to the
output dtypes from a list of input dtypes (if an input is
not a Tensor, the passed value will be None)
* propagate_broadcastable: returns a list of tuples corresponding to
the output broadcastable flags from the input broadcastable
flags (if an input is not a Tensor, the passed value will be
None).
"""
nin = -1 # nin == -1 means: arbitrary number of inputs
nout = 1
out_tensor_class = BaseTensor
@classmethod
def input_wrapper(cls, obj):
"""
Returns a Result from an arbitrary-typed input, if possible.
"""
if isinstance(obj, BaseResult):
return obj
else:
raise TypeError("Expected a Result instance.")
def __init__(self, *inputs):
inputs = map(self.input_wrapper, inputs)
if self.nin >= 0:
if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)") \
% (self, len(inputs), self.nin)
i_broadcastables = [getattr(input, 'broadcastable', None) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_broadcastables = utils.from_return_values(self.propagate_broadcastable(*i_broadcastables))
o_dtypes = utils.from_return_values(self.propagate_dtype(*i_dtypes))
self.inputs = inputs
self.outputs = [self.out_tensor_class(dtype, broadcastable) for broadcastable, dtype in zip(o_broadcastables, o_dtypes)]
def propagate_broadcastable(self, *inputs):
raise AbstractFunctionError()
def propagate_dtype(self, *i_dtypes):
rval = set([dtype for dtype in i_dtypes if dtype is not None])
if len(rval) == 0:
raise ValueError("Cannot infer the dtypes of the outputs with no Tensor inputs.")
elif len(rval) > 1:
raise ValueError("The dtypes of all inputs should be identical.")
return [rval.pop()] * self.nout
......@@ -78,7 +78,7 @@ class Elemwise(Op):
return code_cleanup
@classmethod
def inplace_version(cls):
def inplace_version(cls, dmap = {0:0}):
class Ret(cls, Destroyer):
def destroy_map(self):
return {self.outputs[0]: [self.inputs[0]]}
......
......@@ -165,7 +165,21 @@ class _test_MergeOptimizer(unittest.TestCase):
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]"
def test_2(self):
class _test_ConstantFinder(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
y.data = 2
z.data = 2
e = op1(x, y, z)
g = env([x], [e])
ConstantFinder().optimize(g)
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(x, y, y)]" \
or str(g) == "[Op1(x, z, z)]"
def test_1(self):
x, y, z = inputs()
y.data = 2
z.data = 2
......
......@@ -5,7 +5,7 @@ from copy import copy
import inspect
from gof import ResultBase, Op, utils, Destroyer, Viewer, AbstractFunctionError
from base_tensor import BaseTensor
from base_tensor import BaseTensor, BaseTensorOp
from elemwise import Elemwise
......@@ -114,64 +114,45 @@ def _assert_tensor_scalar(x, a):
raise ValueError("The second argument must be a scalar.")
class _Op(Op):
class _Op(BaseTensorOp):
"""A convenient base for the ops in this file"""
nin = -1
nout = 1
_destroy_map = {}
out_tensor_class = Tensor
def __init__(self, *inputs):
def as_tensor(obj):
if isinstance(obj, Tensor):
return obj
else:
return tinit(obj)
inputs = map(as_tensor, inputs)
if self.nin >= 0:
if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)") \
% (self, len(inputs), self.nin)
i_broadcastables = [getattr(input, 'broadcastable', None) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_broadcastables = utils.from_return_values(self.propagate_broadcastable(*i_broadcastables))
o_dtypes = utils.from_return_values(self.propagate_dtype(*i_dtypes))
self.inputs = inputs
self.outputs = [Tensor(dtype, broadcastable) for broadcastable, dtype in zip(o_broadcastables, o_dtypes)]
def propagate_broadcastable(self, *inputs):
raise AbstractFunctionError()
@classmethod
def input_wrapper(cls, obj):
if isinstance(obj, Tensor):
return obj
else:
return tinit(obj)
# nin = -1
# nout = 1
def propagate_dtype(self, *i_dtypes):
def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype)
for dtype in dtypes:
z = z + numpy.zeros((), dtype = dtype)
return str(z.dtype)
for dtype in i_dtypes:
if dtype is None:
raise TypeError("Expected a Tensor.")
upcasted = upcast(*i_dtypes)
return [upcasted] * self.nout
# try:
# dmap = self.destroy_map()
# except AttributeError:
# dmap = {}
# rval = []
# for i in xrange(self.nout):
# if i in dmap:
# destroyed = dmap[output]
# if len(destroyed) != 1:
# raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# rval.append(destroyed[0])
# else:
# rval.append(upcasted)
# return rval
# def upcast(dtype, *dtypes):
# z = numpy.zeros((), dtype = dtype)
# for dtype in dtypes:
# z = z + numpy.zeros((), dtype = dtype)
# return str(z.dtype)
# for dtype in i_dtypes:
# if dtype is None:
# raise TypeError("Expected a Tensor.")
# upcasted = upcast(*i_dtypes)
# return [upcasted] * self.nout
# # try:
# # dmap = self.destroy_map()
# # except AttributeError:
# # dmap = {}
# # rval = []
# # for i in xrange(self.nout):
# # if i in dmap:
# # destroyed = dmap[output]
# # if len(destroyed) != 1:
# # raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# # rval.append(destroyed[0])
# # else:
# # rval.append(upcasted)
# # return rval
def impl(self, *inputs):
raise AbstractFunctionError()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论