提交 8b282565 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Thomas Wiecki

Move theano.tensor.basic_opt.MakeVector and make_vector to theano.tensor.basic

上级 5f7dfd62
...@@ -33,8 +33,7 @@ from theano.gpuarray.basic_ops import ( ...@@ -33,8 +33,7 @@ from theano.gpuarray.basic_ops import (
from theano.gpuarray.elemwise import GpuDimShuffle, GpuElemwise from theano.gpuarray.elemwise import GpuDimShuffle, GpuElemwise
from theano.gpuarray.subtensor import GpuSubtensor from theano.gpuarray.subtensor import GpuSubtensor
from theano.gpuarray.type import GpuArrayType, get_context, gpuarray_shared_constructor from theano.gpuarray.type import GpuArrayType, get_context, gpuarray_shared_constructor
from theano.tensor.basic import Alloc, Split, alloc from theano.tensor.basic import Alloc, MakeVector, Split, alloc
from theano.tensor.basic_opt import MakeVector
from theano.tensor.shape import Shape, Shape_i from theano.tensor.shape import Shape, Shape_i
from theano.tensor.type import TensorType, fmatrix, iscalar, lscalar, matrix from theano.tensor.type import TensorType, fmatrix, iscalar, lscalar, matrix
......
...@@ -34,7 +34,7 @@ from theano.gpuarray.linalg import GpuCholesky, GpuCusolverSolve, cusolver_avail ...@@ -34,7 +34,7 @@ from theano.gpuarray.linalg import GpuCholesky, GpuCusolverSolve, cusolver_avail
from theano.gpuarray.subtensor import GpuSubtensor from theano.gpuarray.subtensor import GpuSubtensor
from theano.gpuarray.type import GpuArrayType, get_context, gpuarray_shared_constructor from theano.gpuarray.type import GpuArrayType, get_context, gpuarray_shared_constructor
from theano.graph.opt import check_stack_trace from theano.graph.opt import check_stack_trace
from theano.tensor.basic import Alloc, AllocEmpty, Rebroadcast from theano.tensor.basic import Alloc, AllocEmpty, MakeVector, Rebroadcast
from theano.tensor.blas import batched_dot from theano.tensor.blas import batched_dot
from theano.tensor.math import dot, eq, exp, gt, tanh from theano.tensor.math import dot, eq, exp, gt, tanh
from theano.tensor.nnet import abstract_conv from theano.tensor.nnet import abstract_conv
...@@ -68,7 +68,7 @@ def _check_stack_trace(thing): ...@@ -68,7 +68,7 @@ def _check_stack_trace(thing):
Shape_i, Shape_i,
Shape, Shape,
theano.compile.ops.DeepCopyOp, theano.compile.ops.DeepCopyOp,
theano.tensor.basic_opt.MakeVector, MakeVector,
theano.tensor.subtensor.Subtensor, theano.tensor.subtensor.Subtensor,
theano.tensor.elemwise.Elemwise, theano.tensor.elemwise.Elemwise,
theano.ifelse.IfElse, theano.ifelse.IfElse,
......
...@@ -15,9 +15,7 @@ from theano.graph.optdb import Query ...@@ -15,9 +15,7 @@ from theano.graph.optdb import Query
from theano.ifelse import ifelse from theano.ifelse import ifelse
from theano.link.jax import JAXLinker from theano.link.jax import JAXLinker
from theano.scan.basic import scan from theano.scan.basic import scan
from theano.tensor import basic
from theano.tensor import basic as tt from theano.tensor import basic as tt
from theano.tensor import basic_opt as tt_opt
from theano.tensor import blas as tt_blas from theano.tensor import blas as tt_blas
from theano.tensor import elemwise as tt_elemwise from theano.tensor import elemwise as tt_elemwise
from theano.tensor import extra_ops as tt_extra_ops from theano.tensor import extra_ops as tt_extra_ops
...@@ -184,15 +182,13 @@ def test_jax_compile_ops(): ...@@ -184,15 +182,13 @@ def test_jax_compile_ops():
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
x_np = np.zeros((20, 1, 1)) x_np = np.zeros((20, 1, 1))
x = basic.Rebroadcast((0, False), (1, True), (2, False))( x = tt.Rebroadcast((0, False), (1, True), (2, False))(tt.as_tensor_variable(x_np))
tt.as_tensor_variable(x_np)
)
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
with config.change_flags(compute_test_value="off"): with config.change_flags(compute_test_value="off"):
x = basic.Rebroadcast((0, True), (1, False), (2, False))( x = tt.Rebroadcast((0, True), (1, False), (2, False))(
tt.as_tensor_variable(x_np) tt.as_tensor_variable(x_np)
) )
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
...@@ -654,7 +650,7 @@ def test_jax_CAReduce(): ...@@ -654,7 +650,7 @@ def test_jax_CAReduce():
def test_jax_MakeVector(): def test_jax_MakeVector():
x = tt_opt.make_vector(1, 2, 3) x = tt.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
......
...@@ -86,6 +86,7 @@ from theano.sparse.basic import ( ...@@ -86,6 +86,7 @@ from theano.sparse.basic import (
_mtypes, _mtypes,
) )
from theano.sparse.opt import CSMGradC, StructuredDotCSC, UsmmCscDense from theano.sparse.opt import CSMGradC, StructuredDotCSC, UsmmCscDense
from theano.tensor.basic import MakeVector
from theano.tensor.elemwise import DimShuffle, Elemwise from theano.tensor.elemwise import DimShuffle, Elemwise
from theano.tensor.math import sum as tt_sum from theano.tensor.math import sum as tt_sum
from theano.tensor.shape import Shape_i from theano.tensor.shape import Shape_i
...@@ -1876,7 +1877,7 @@ def test_shape(): ...@@ -1876,7 +1877,7 @@ def test_shape():
assert len(topo) == 3 assert len(topo) == 3
assert isinstance(topo[0].op, Shape_i) assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, Shape_i) assert isinstance(topo[1].op, Shape_i)
assert isinstance(topo[2].op, theano.tensor.basic_opt.MakeVector) assert isinstance(topo[2].op, MakeVector)
def test_may_share_memory(): def test_may_share_memory():
......
...@@ -50,6 +50,7 @@ from theano.tensor.basic import ( ...@@ -50,6 +50,7 @@ from theano.tensor.basic import (
ExtractDiag, ExtractDiag,
Eye, Eye,
Join, Join,
MakeVector,
PermuteRowElements, PermuteRowElements,
ScalarFromTensor, ScalarFromTensor,
Split, Split,
...@@ -75,6 +76,7 @@ from theano.tensor.basic import ( ...@@ -75,6 +76,7 @@ from theano.tensor.basic import (
horizontal_stack, horizontal_stack,
inverse_permutation, inverse_permutation,
join, join,
make_vector,
mgrid, mgrid,
nonzero, nonzero,
nonzero_values, nonzero_values,
...@@ -99,7 +101,6 @@ from theano.tensor.basic import ( ...@@ -99,7 +101,6 @@ from theano.tensor.basic import (
vertical_stack, vertical_stack,
zeros_like, zeros_like,
) )
from theano.tensor.basic_opt import MakeVector, make_vector
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano.tensor.exceptions import EmptyConstantError, NotScalarConstantError from theano.tensor.exceptions import EmptyConstantError, NotScalarConstantError
from theano.tensor.math import dense_dot, eq from theano.tensor.math import dense_dot, eq
......
...@@ -34,6 +34,7 @@ from theano.tensor import inplace ...@@ -34,6 +34,7 @@ from theano.tensor import inplace
from theano.tensor.basic import ( from theano.tensor.basic import (
Alloc, Alloc,
Join, Join,
MakeVector,
Rebroadcast, Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
Split, Split,
...@@ -41,10 +42,10 @@ from theano.tensor.basic import ( ...@@ -41,10 +42,10 @@ from theano.tensor.basic import (
_convert_to_int8, _convert_to_int8,
as_tensor_variable, as_tensor_variable,
join, join,
make_vector,
tile, tile,
) )
from theano.tensor.basic_opt import ( from theano.tensor.basic_opt import (
MakeVector,
ShapeFeature, ShapeFeature,
assert_op, assert_op,
local_canonicalize_alloc, local_canonicalize_alloc,
...@@ -55,7 +56,6 @@ from theano.tensor.basic_opt import ( ...@@ -55,7 +56,6 @@ from theano.tensor.basic_opt import (
local_useless_dimshuffle_in_reshape, local_useless_dimshuffle_in_reshape,
local_useless_elemwise, local_useless_elemwise,
local_useless_reshape, local_useless_reshape,
make_vector,
register_specialize, register_specialize,
) )
from theano.tensor.blas import Dot22, Gemv from theano.tensor.blas import Dot22, Gemv
......
...@@ -10,8 +10,8 @@ from theano.compile.ops import DeepCopyOp ...@@ -10,8 +10,8 @@ from theano.compile.ops import DeepCopyOp
from theano.configdefaults import config from theano.configdefaults import config
from theano.graph.fg import FunctionGraph from theano.graph.fg import FunctionGraph
from theano.misc.safe_asarray import _asarray from theano.misc.safe_asarray import _asarray
from theano.tensor.basic import as_tensor_variable, constant from theano.tensor.basic import MakeVector, as_tensor_variable, constant
from theano.tensor.basic_opt import MakeVector, ShapeFeature from theano.tensor.basic_opt import ShapeFeature
from theano.tensor.elemwise import DimShuffle, Elemwise from theano.tensor.elemwise import DimShuffle, Elemwise
from theano.tensor.shape import ( from theano.tensor.shape import (
Reshape, Reshape,
......
...@@ -8,7 +8,7 @@ import theano.sparse ...@@ -8,7 +8,7 @@ import theano.sparse
import theano.tensor as tt import theano.tensor as tt
from tests import unittest_tools as utt from tests import unittest_tools as utt
from theano.misc.may_share_memory import may_share_memory from theano.misc.may_share_memory import may_share_memory
from theano.tensor.basic_opt import MakeVector from theano.tensor.basic import MakeVector
from theano.tensor.shape import Shape_i, specify_shape from theano.tensor.shape import Shape_i, specify_shape
......
...@@ -20,11 +20,11 @@ from theano.tensor.basic import ( ...@@ -20,11 +20,11 @@ from theano.tensor.basic import (
ARange, ARange,
Eye, Eye,
Join, Join,
MakeVector,
Rebroadcast, Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
TensorFromScalar, TensorFromScalar,
) )
from theano.tensor.basic_opt import MakeVector
from theano.tensor.blas import BatchedDot from theano.tensor.blas import BatchedDot
from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise from theano.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from theano.tensor.extra_ops import ( from theano.tensor.extra_ops import (
......
...@@ -183,7 +183,7 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -183,7 +183,7 @@ def as_tensor_variable(x, name=None, ndim=None):
# `MakeVector` is a better option due to its `get_scalar_constant_value` # `MakeVector` is a better option due to its `get_scalar_constant_value`
# support. # support.
dtype = ts.upcast(*[i.dtype for i in x if hasattr(i, "dtype")]) dtype = ts.upcast(*[i.dtype for i in x if hasattr(i, "dtype")])
return theano.tensor.basic_opt.MakeVector(dtype)(*x) return MakeVector(dtype)(*x)
return stack(x) return stack(x)
...@@ -500,9 +500,7 @@ def get_scalar_constant_value( ...@@ -500,9 +500,7 @@ def get_scalar_constant_value(
elif ( elif (
v.owner.inputs[0].owner v.owner.inputs[0].owner
and isinstance( and isinstance(v.owner.inputs[0].owner.op, MakeVector)
v.owner.inputs[0].owner.op, theano.tensor.basic_opt.MakeVector
)
and and
# MakeVector normally accept only scalar as input. # MakeVector normally accept only scalar as input.
# We put this check in case there is change in the future # We put this check in case there is change in the future
...@@ -1589,6 +1587,118 @@ alloc = Alloc() ...@@ -1589,6 +1587,118 @@ alloc = Alloc()
pprint.assign(alloc, printing.FunctionPrinter("alloc")) pprint.assign(alloc, printing.FunctionPrinter("alloc"))
class MakeVector(COp):
    """Concatenate a number of scalars together into a vector.

    This is a simple version of stack() that introduces far less cruft
    into the graph. Should work with 0 inputs. The constant_folding
    optimization will remove it.
    """

    # Instances are distinguished only by their output dtype.
    __props__ = ("dtype",)

    def __init__(self, dtype="int64"):
        # dtype: dtype of the output vector; inputs must be upcastable to
        # it (this Op never downcasts -- see the assert in make_node).
        self.dtype = dtype

    def make_node(self, *inputs):
        """Build an ``Apply`` producing a 1-d tensor from 0-d (scalar) inputs.

        Inputs are converted to tensor variables and, when their types
        disagree or differ from ``self.dtype``, upcast to ``self.dtype``.
        Raises ``TypeError`` if any input cannot be upcast to ``self.dtype``.
        """
        inputs = list(map(as_tensor_variable, inputs))
        if not all(a.type == inputs[0].type for a in inputs) or (
            len(inputs) > 0 and inputs[0].dtype != self.dtype
        ):
            dtype = ts.upcast(self.dtype, *[i.dtype for i in inputs])
            # upcast the input to the determined dtype,
            # but don't downcast anything
            assert dtype == self.dtype, (
                "The upcast of the inputs to MakeVector should match the "
                "dtype given in __init__."
            )
            if not all(self.dtype == cast(i, dtype=dtype).dtype for i in inputs):
                raise TypeError(
                    "MakeVector.make_node expected inputs"
                    f" upcastable to {self.dtype}. got {[i.dtype for i in inputs]}"
                )
            inputs = [cast(i, dtype=dtype) for i in inputs]
        assert all(self.dtype == a.dtype for a in inputs)
        # Only scalar (0-d) inputs are accepted.
        assert all(a.ndim == 0 for a in inputs)
        if inputs:
            dtype = inputs[0].type.dtype
        else:
            dtype = self.dtype
        # bcastable = (len(inputs) == 1)
        # The output is deliberately never broadcastable, even for a
        # length-1 vector (see the commented-out alternative above).
        bcastable = False
        otype = TensorType(broadcastable=(bcastable,), dtype=dtype)
        return Apply(self, inputs, [otype()])

    def perform(self, node, inputs, out_):
        """Fill the (possibly preallocated) output with the scalar inputs."""
        (out,) = out_
        # not calling theano._asarray as optimization
        if (out[0] is None) or (out[0].size != len(inputs)):
            out[0] = _asarray(inputs, dtype=node.outputs[0].dtype)
        else:
            # assume that out has correct dtype. there is no cheap way to check
            out[0][...] = inputs

    def c_code_cache_version(self):
        # Bump whenever the C code emitted below changes.
        return (2,)

    def c_code(self, node, name, inp, out_, props):
        (out,) = out_
        # Shouldn't use PyArray_TYPE(inp[0]) for the dtype
        # when len(inp) == 0 (we need to support this case).
        # So there will be (1 * nb_dtype) + ((nb len(inp) - 1 ))
        # different c code with the following algo
        out_shape = len(inp)
        out_num = np.dtype(node.outputs[0].dtype).num
        # don't use dtype_%(out)s as when check_input=False, it isn't defined.
        out_dtype = node.outputs[0].type.dtype_specs()[1]
        if len(inp) > 0:
            assert self.dtype == node.inputs[0].dtype
            # With at least one input, read the type code from it at runtime.
            out_num = f"PyArray_TYPE({inp[0]})"
        # Allocate (or reuse) a 1-d output of the right length.
        ret = (
            """
        npy_intp dims[1];
        dims[0] = %(out_shape)s;
        if(!%(out)s || PyArray_DIMS(%(out)s)[0] != %(out_shape)s){
            Py_XDECREF(%(out)s);
            %(out)s = (PyArrayObject*)PyArray_EMPTY(1, dims, %(out_num)s, 0);
        }
        """
            % locals()
        )
        # Copy each scalar input into its slot of the output vector.
        for idx, i in enumerate(inp):
            ret += (
                """
            *((%(out_dtype)s *)PyArray_GETPTR1(%(out)s, %(idx)s)) = *((%(out_dtype)s *) PyArray_DATA(%(i)s));
            """
                % locals()
            )
        return ret

    def infer_shape(self, fgraph, node, ishapes):
        # Output length equals the number of (scalar) inputs.
        return [(len(ishapes),)]

    def grad(self, inputs, output_gradients):
        # If the output is of an integer dtype, no gradient shall pass
        if self.dtype in discrete_dtypes:
            return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
        # Each scalar input receives the matching entry of the output gradient.
        grads = []
        for i, inp in enumerate(inputs):
            grads.append(output_gradients[0][i])
        return grads

    def R_op(self, inputs, eval_points):
        # The Op is linear in its inputs, so the R-operator is the Op
        # applied to the evaluation points.
        if None in eval_points:
            return [None]
        return self.make_node(*eval_points).outputs


# Default instance (int64 output dtype), the module-level helper.
make_vector = MakeVector()
def transfer(var, target): def transfer(var, target):
""" """
Return a version of `var` transferred to `target`. Return a version of `var` transferred to `target`.
......
...@@ -20,7 +20,6 @@ from theano.compile.ops import ViewOp ...@@ -20,7 +20,6 @@ from theano.compile.ops import ViewOp
from theano.configdefaults import config from theano.configdefaults import config
from theano.graph import toolbox from theano.graph import toolbox
from theano.graph.basic import ( from theano.graph.basic import (
Apply,
Constant, Constant,
Variable, Variable,
ancestors, ancestors,
...@@ -28,7 +27,7 @@ from theano.graph.basic import ( ...@@ -28,7 +27,7 @@ from theano.graph.basic import (
io_toposort, io_toposort,
) )
from theano.graph.fg import InconsistencyError from theano.graph.fg import InconsistencyError
from theano.graph.op import COp, get_test_value from theano.graph.op import get_test_value
from theano.graph.opt import ( from theano.graph.opt import (
GlobalOptimizer, GlobalOptimizer,
OpRemove, OpRemove,
...@@ -44,7 +43,6 @@ from theano.graph.utils import ( ...@@ -44,7 +43,6 @@ from theano.graph.utils import (
TestValueError, TestValueError,
get_variable_trace_string, get_variable_trace_string,
) )
from theano.misc.safe_asarray import _asarray
from theano.printing import pprint from theano.printing import pprint
from theano.tensor.basic import ( from theano.tensor.basic import (
Alloc, Alloc,
...@@ -52,6 +50,7 @@ from theano.tensor.basic import ( ...@@ -52,6 +50,7 @@ from theano.tensor.basic import (
ARange, ARange,
Flatten, Flatten,
Join, Join,
MakeVector,
Rebroadcast, Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
Split, Split,
...@@ -66,6 +65,7 @@ from theano.tensor.basic import ( ...@@ -66,6 +65,7 @@ from theano.tensor.basic import (
get_scalar_constant_value, get_scalar_constant_value,
get_vector_length, get_vector_length,
join, join,
make_vector,
ones_like, ones_like,
patternbroadcast, patternbroadcast,
switch, switch,
...@@ -778,120 +778,6 @@ def local_scalar_tensor_scalar(fgraph, node): ...@@ -778,120 +778,6 @@ def local_scalar_tensor_scalar(fgraph, node):
##################################### #####################################
# ShapeFeature, Shape optimizations # ShapeFeature, Shape optimizations
##################################### #####################################
class MakeVector(COp):
    """Concatenate a number of scalars together into a vector.

    This is a simple version of stack() that introduces far less cruft
    into the graph. Should work with 0 inputs. The constant_folding
    optimization will remove it.
    """

    # Instances are distinguished only by their output dtype.
    __props__ = ("dtype",)

    def __init__(self, dtype="int64"):
        # dtype: dtype of the output vector; inputs must be upcastable to
        # it (this Op never downcasts -- see the assert in make_node).
        self.dtype = dtype

    def make_node(self, *inputs):
        """Build an ``Apply`` producing a 1-d tensor from 0-d (scalar) inputs.

        Inputs are converted to tensor variables and, when their types
        disagree or differ from ``self.dtype``, upcast to ``self.dtype``.
        Raises ``TypeError`` if any input cannot be upcast to ``self.dtype``.
        """
        inputs = list(map(as_tensor_variable, inputs))
        if not all(a.type == inputs[0].type for a in inputs) or (
            len(inputs) > 0 and inputs[0].dtype != self.dtype
        ):
            dtype = ts.upcast(self.dtype, *[i.dtype for i in inputs])
            # upcast the input to the determined dtype,
            # but don't downcast anything
            assert dtype == self.dtype, (
                "The upcast of the inputs to MakeVector should match the "
                "dtype given in __init__."
            )
            if not all(self.dtype == cast(i, dtype=dtype).dtype for i in inputs):
                raise TypeError(
                    "MakeVector.make_node expected inputs"
                    f" upcastable to {self.dtype}. got {[i.dtype for i in inputs]}"
                )
            inputs = [cast(i, dtype=dtype) for i in inputs]
        assert all(self.dtype == a.dtype for a in inputs)
        # Only scalar (0-d) inputs are accepted.
        assert all(a.ndim == 0 for a in inputs)
        if inputs:
            dtype = inputs[0].type.dtype
        else:
            dtype = self.dtype
        # bcastable = (len(inputs) == 1)
        # The output is deliberately never broadcastable, even for a
        # length-1 vector (see the commented-out alternative above).
        bcastable = False
        otype = TensorType(broadcastable=(bcastable,), dtype=dtype)
        return Apply(self, inputs, [otype()])

    def perform(self, node, inputs, out_):
        """Fill the (possibly preallocated) output with the scalar inputs."""
        (out,) = out_
        # not calling theano._asarray as optimization
        if (out[0] is None) or (out[0].size != len(inputs)):
            out[0] = _asarray(inputs, dtype=node.outputs[0].dtype)
        else:
            # assume that out has correct dtype. there is no cheap way to check
            out[0][...] = inputs

    def c_code_cache_version(self):
        # Bump whenever the C code emitted below changes.
        return (2,)

    def c_code(self, node, name, inp, out_, props):
        (out,) = out_
        # Shouldn't use PyArray_TYPE(inp[0]) for the dtype
        # when len(inp) == 0 (we need to support this case).
        # So there will be (1 * nb_dtype) + ((nb len(inp) - 1 ))
        # different c code with the following algo
        out_shape = len(inp)
        out_num = np.dtype(node.outputs[0].dtype).num
        # don't use dtype_%(out)s as when check_input=False, it isn't defined.
        out_dtype = node.outputs[0].type.dtype_specs()[1]
        if len(inp) > 0:
            assert self.dtype == node.inputs[0].dtype
            # With at least one input, read the type code from it at runtime.
            out_num = f"PyArray_TYPE({inp[0]})"
        # Allocate (or reuse) a 1-d output of the right length.
        ret = (
            """
        npy_intp dims[1];
        dims[0] = %(out_shape)s;
        if(!%(out)s || PyArray_DIMS(%(out)s)[0] != %(out_shape)s){
            Py_XDECREF(%(out)s);
            %(out)s = (PyArrayObject*)PyArray_EMPTY(1, dims, %(out_num)s, 0);
        }
        """
            % locals()
        )
        # Copy each scalar input into its slot of the output vector.
        for idx, i in enumerate(inp):
            ret += (
                """
            *((%(out_dtype)s *)PyArray_GETPTR1(%(out)s, %(idx)s)) = *((%(out_dtype)s *) PyArray_DATA(%(i)s));
            """
                % locals()
            )
        return ret

    def infer_shape(self, fgraph, node, ishapes):
        # Output length equals the number of (scalar) inputs.
        return [(len(ishapes),)]

    def grad(self, inputs, output_gradients):
        # If the output is of an integer dtype, no gradient shall pass
        if self.dtype in discrete_dtypes:
            return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
        # Each scalar input receives the matching entry of the output gradient.
        grads = []
        for i, inp in enumerate(inputs):
            grads.append(output_gradients[0][i])
        return grads

    def R_op(self, inputs, eval_points):
        # The Op is linear in its inputs, so the R-operator is the Op
        # applied to the evaluation points.
        if None in eval_points:
            return [None]
        return self.make_node(*eval_points).outputs


# Default instance (int64 output dtype), the module-level helper.
make_vector = MakeVector()
class MakeVectorPrinter: class MakeVectorPrinter:
def process(self, r, pstate): def process(self, r, pstate):
if r.owner is None: if r.owner is None:
......
...@@ -27,6 +27,7 @@ from theano.misc.safe_asarray import _asarray ...@@ -27,6 +27,7 @@ from theano.misc.safe_asarray import _asarray
from theano.tensor.basic import ( from theano.tensor.basic import (
Alloc, Alloc,
Join, Join,
MakeVector,
alloc, alloc,
as_tensor_variable, as_tensor_variable,
cast, cast,
...@@ -40,7 +41,6 @@ from theano.tensor.basic import ( ...@@ -40,7 +41,6 @@ from theano.tensor.basic import (
) )
from theano.tensor.basic_opt import ( from theano.tensor.basic_opt import (
FusionOptimizer, FusionOptimizer,
MakeVector,
_fill_chain, _fill_chain,
broadcast_like, broadcast_like,
encompasses_broadcastable, encompasses_broadcastable,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论