提交 9eb77476 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Thomas Wiecki

Move theano.compile.ops.Rebroadcast to theano.tensor.basic

上级 f137ba7c
...@@ -4,8 +4,9 @@ import numpy as np ...@@ -4,8 +4,9 @@ import numpy as np
from tests import unittest_tools as utt from tests import unittest_tools as utt
from theano import function from theano import function
from theano.compile.ops import Rebroadcast, as_op from theano.compile.ops import as_op
from theano.configdefaults import config from theano.configdefaults import config
from theano.tensor.basic import Rebroadcast
from theano.tensor.type import TensorType, dmatrix, dtensor4, dvector from theano.tensor.type import TensorType, dmatrix, dtensor4, dvector
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
import theano.scalar.basic as ts import theano.scalar.basic as ts
from theano.compile.function import function from theano.compile.function import function
from theano.compile.mode import Mode from theano.compile.mode import Mode
from theano.compile.ops import DeepCopyOp, Rebroadcast, ViewOp from theano.compile.ops import DeepCopyOp, ViewOp
from theano.compile.sharedvalue import shared from theano.compile.sharedvalue import shared
from theano.configdefaults import config from theano.configdefaults import config
from theano.graph.fg import FunctionGraph from theano.graph.fg import FunctionGraph
...@@ -15,6 +15,7 @@ from theano.graph.optdb import Query ...@@ -15,6 +15,7 @@ from theano.graph.optdb import Query
from theano.ifelse import ifelse from theano.ifelse import ifelse
from theano.link.jax import JAXLinker from theano.link.jax import JAXLinker
from theano.scan.basic import scan from theano.scan.basic import scan
from theano.tensor import basic
from theano.tensor import basic as tt from theano.tensor import basic as tt
from theano.tensor import blas as tt_blas from theano.tensor import blas as tt_blas
from theano.tensor import elemwise as tt_elemwise from theano.tensor import elemwise as tt_elemwise
...@@ -183,13 +184,17 @@ def test_jax_compile_ops(): ...@@ -183,13 +184,17 @@ def test_jax_compile_ops():
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
x_np = np.zeros((20, 1, 1)) x_np = np.zeros((20, 1, 1))
x = Rebroadcast((0, False), (1, True), (2, False))(tt.as_tensor_variable(x_np)) x = basic.Rebroadcast((0, False), (1, True), (2, False))(
tt.as_tensor_variable(x_np)
)
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
with config.change_flags(compute_test_value="off"): with config.change_flags(compute_test_value="off"):
x = Rebroadcast((0, True), (1, False), (2, False))(tt.as_tensor_variable(x_np)) x = basic.Rebroadcast((0, True), (1, False), (2, False))(
tt.as_tensor_variable(x_np)
)
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
with pytest.raises(ValueError): with pytest.raises(ValueError):
......
...@@ -51,12 +51,10 @@ from theano.compile.monitormode import MonitorMode ...@@ -51,12 +51,10 @@ from theano.compile.monitormode import MonitorMode
from theano.compile.ops import ( from theano.compile.ops import (
DeepCopyOp, DeepCopyOp,
FromFunctionOp, FromFunctionOp,
Rebroadcast,
ViewOp, ViewOp,
as_op, as_op,
deep_copy_op, deep_copy_op,
register_deep_copy_op_c_code, register_deep_copy_op_c_code,
register_rebroadcast_c_code,
register_view_op_c_code, register_view_op_c_code,
view_op, view_op,
) )
......
...@@ -8,11 +8,7 @@ help make new Ops more rapidly. ...@@ -8,11 +8,7 @@ help make new Ops more rapidly.
import copy import copy
import pickle import pickle
import warnings import warnings
from collections import OrderedDict
import numpy as np
import theano
from theano.graph.basic import Apply from theano.graph.basic import Apply
from theano.graph.op import COp, Op from theano.graph.op import COp, Op
from theano.graph.type import CType from theano.graph.type import CType
...@@ -333,177 +329,3 @@ def as_op(itypes, otypes, infer_shape=None): ...@@ -333,177 +329,3 @@ def as_op(itypes, otypes, infer_shape=None):
return FromFunctionOp(fn, itypes, otypes, infer_shape) return FromFunctionOp(fn, itypes, otypes, infer_shape)
return make_op return make_op
def register_rebroadcast_c_code(typ, code, version=()):
"""
Tell Rebroadcast how to generate C code for a Theano Type.
typ : Theano type
It must be the Theano class itself and not an instance of the class.
code : C code
That checks if the dimension %(axis)s is of shape 1 for the Theano type
'typ'. Use %(iname)s and %(oname)s for the input and output C variable
names respectively, and %(axis)s for the axis that we need to check.
This code is put in a loop for all axes.
version
A number indicating the version of the code, for cache.
"""
Rebroadcast.c_code_and_version[typ] = (code, version)
class Rebroadcast(COp):
"""
Change the input's broadcastable fields in some predetermined way.
See Also
--------
unbroadcast <theano.tensor.unbroadcast>
addbroadcast <theano.tensor.addbroadcast>
patternbroadcast <theano.tensor.patternbroadcast>
Notes
-----
Works inplace and works for CudaNdarrayType.
Examples
--------
`Rebroadcast((0, True), (1, False))(x)` would make `x` broadcastable in
axis 0 and not broadcastable in axis 1.
"""
view_map = {0: [0]}
_f16_ok = True
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version = {}
check_input = False
__props__ = ("axis",)
_f16_ok = True
def __init__(self, *axis):
# Sort them to make sure we merge all possible case.
items = sorted(axis)
self.axis = OrderedDict(items)
for axis, broad in self.axis.items():
if not isinstance(axis, (np.integer, int)):
raise TypeError(f"Rebroadcast needs integer axes. Got {axis}")
if not isinstance(broad, (np.bool_, bool)):
raise TypeError(
f"Rebroadcast needs bool for new broadcast pattern. Got {broad}"
)
def __hash__(self):
# Need special __hash__ as dict aren't hashable.
# no ambiguity because each item key is unique
items = sorted(self.axis.items())
return hash((type(self), tuple(items)))
def __str__(self):
if len(self.axis) == 0:
broadcast_pattern = []
else:
broadcast_pattern = ["?" for i in range(1 + max(self.axis.keys()))]
for k, v in self.axis.items():
broadcast_pattern[k] = str(int(v))
return f"{self.__class__.__name__}{{{','.join(broadcast_pattern)}}}"
def make_node(self, x):
if self.axis.keys() and (x.ndim <= max(self.axis.keys())):
raise ValueError("Trying to rebroadcast non-existent dimension")
t = x.type.clone(
broadcastable=[
self.axis.get(i, b) for i, b in enumerate(x.type.broadcastable)
]
)
return Apply(self, [x], [t()])
def perform(self, node, inp, out_):
(x,) = inp
(out,) = out_
for axis, value in self.axis.items():
if value and x.shape[axis] != 1:
raise ValueError(
f"Dimension {axis} in Rebroadcast's input was"
f" supposed to be 1 (got {x.shape[axis]} instead)"
)
out[0] = x
def grad(self, inp, grads):
(x,) = inp
(gz,) = grads
# restore the broadcasting pattern of the input
return (
Rebroadcast(
*[
(axis, x.type.broadcastable[axis])
for axis, value in self.axis.items()
]
)(gz),
)
def infer_shape(self, fgraph, node, ishapes):
assert len(ishapes) == 1
l = []
one = theano.tensor.basic.constant(1)
for ax in range(len(ishapes[0])):
if self.axis.get(ax, False):
l.append(one)
else:
l.append(ishapes[0][ax])
return [tuple(l)]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self(*eval_points, **dict(return_list=True))
def c_code(self, node, nodename, inp, out, sub):
(iname,) = inp
(oname,) = out
fail = sub["fail"]
itype = node.inputs[0].type.__class__
if itype in self.c_code_and_version:
code, version = self.c_code_and_version[itype]
final_code = ""
for axis, value in self.axis.items():
if value:
final_code += code % locals()
return (
final_code
+ f"""
Py_XDECREF({oname});
{oname} = {iname};
Py_XINCREF({oname});
"""
)
raise NotImplementedError()
def c_code_cache_version(self):
version = []
# If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(
self.c_code_and_version.items(), key=lambda pair: str(pair[0])
):
if not v:
warnings.warn(
f"Type {t} has C code for Rebroadcast, but it "
"has no version. You should add a 'version' "
"keyword arg when calling "
"register_rebroadcast_c_code.",
stacklevel=2,
)
return ()
version.append((str(t), v))
if version:
version.append(1)
return tuple(version)
...@@ -8,6 +8,7 @@ import numpy as np ...@@ -8,6 +8,7 @@ import numpy as np
import theano import theano
import theano.scalar as ts import theano.scalar as ts
import theano.tensor as tt import theano.tensor as tt
import theano.tensor.basic
from theano.compile import SharedVariable from theano.compile import SharedVariable
from theano.configdefaults import config from theano.configdefaults import config
from theano.graph.basic import Constant, Variable from theano.graph.basic import Constant, Variable
...@@ -856,7 +857,7 @@ theano.compile.register_deep_copy_op_c_code( ...@@ -856,7 +857,7 @@ theano.compile.register_deep_copy_op_c_code(
version=(5,), version=(5,),
) )
theano.compile.register_rebroadcast_c_code( theano.tensor.basic.register_rebroadcast_c_code(
GpuArrayType, GpuArrayType,
""" """
if(%(iname)s->ga.dimensions[%(axis)s] != 1){ if(%(iname)s->ga.dimensions[%(axis)s] != 1){
......
...@@ -18,12 +18,11 @@ import numpy as np ...@@ -18,12 +18,11 @@ import numpy as np
import theano.tensor as tt import theano.tensor as tt
from theano.compile import optdb from theano.compile import optdb
from theano.compile.ops import Rebroadcast
from theano.configdefaults import config from theano.configdefaults import config
from theano.graph.basic import Apply, Variable, clone_replace, is_in_ancestors from theano.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from theano.graph.op import _NoPythonOp from theano.graph.op import _NoPythonOp
from theano.graph.opt import GlobalOptimizer, local_optimizer from theano.graph.opt import GlobalOptimizer, local_optimizer
from theano.tensor import opt from theano.tensor import basic, opt
from theano.tensor.shape import Reshape, Shape, SpecifyShape from theano.tensor.shape import Reshape, Shape, SpecifyShape
from theano.tensor.type import TensorType from theano.tensor.type import TensorType
...@@ -485,7 +484,7 @@ acceptable_ops = ( ...@@ -485,7 +484,7 @@ acceptable_ops = (
Shape, Shape,
SpecifyShape, SpecifyShape,
Reshape, Reshape,
Rebroadcast, basic.Rebroadcast,
tt.math.Dot, tt.math.Dot,
tt.math.MaxAndArgmax, tt.math.MaxAndArgmax,
tt.subtensor.Subtensor, tt.subtensor.Subtensor,
......
...@@ -6,7 +6,7 @@ import jax ...@@ -6,7 +6,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax.scipy as jsp import jax.scipy as jsp
from theano.compile.ops import DeepCopyOp, Rebroadcast, ViewOp from theano.compile.ops import DeepCopyOp, ViewOp
from theano.configdefaults import config from theano.configdefaults import config
from theano.graph.fg import FunctionGraph from theano.graph.fg import FunctionGraph
from theano.graph.type import CType from theano.graph.type import CType
...@@ -20,6 +20,7 @@ from theano.tensor.basic import ( ...@@ -20,6 +20,7 @@ from theano.tensor.basic import (
ARange, ARange,
Eye, Eye,
Join, Join,
Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
TensorFromScalar, TensorFromScalar,
) )
......
...@@ -8,6 +8,7 @@ manipulation of tensors. ...@@ -8,6 +8,7 @@ manipulation of tensors.
import builtins import builtins
import logging import logging
import warnings import warnings
from collections import OrderedDict
from collections.abc import Sequence from collections.abc import Sequence
import numpy as np import numpy as np
...@@ -17,7 +18,6 @@ import theano.scalar.sharedvar ...@@ -17,7 +18,6 @@ import theano.scalar.sharedvar
from theano import compile, config, printing from theano import compile, config, printing
from theano import scalar as ts from theano import scalar as ts
from theano.assert_op import Assert, assert_op from theano.assert_op import Assert, assert_op
from theano.compile.ops import Rebroadcast
from theano.gradient import DisconnectedType, grad_not_implemented, grad_undefined from theano.gradient import DisconnectedType, grad_not_implemented, grad_undefined
from theano.graph.basic import Apply, Constant, Variable from theano.graph.basic import Apply, Constant, Variable
from theano.graph.op import COp, Op from theano.graph.op import COp, Op
...@@ -663,6 +663,195 @@ class ScalarFromTensor(COp): ...@@ -663,6 +663,195 @@ class ScalarFromTensor(COp):
scalar_from_tensor = ScalarFromTensor() scalar_from_tensor = ScalarFromTensor()
class Rebroadcast(COp):
"""
Change the input's broadcastable fields in some predetermined way.
See Also
--------
unbroadcast <theano.tensor.unbroadcast>
addbroadcast <theano.tensor.addbroadcast>
patternbroadcast <theano.tensor.patternbroadcast>
Notes
-----
Works inplace and works for CudaNdarrayType.
Examples
--------
`Rebroadcast((0, True), (1, False))(x)` would make `x` broadcastable in
axis 0 and not broadcastable in axis 1.
"""
view_map = {0: [0]}
_f16_ok = True
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version = {}
check_input = False
__props__ = ("axis",)
_f16_ok = True
def __init__(self, *axis):
# Sort them to make sure we merge all possible case.
items = sorted(axis)
self.axis = OrderedDict(items)
for axis, broad in self.axis.items():
if not isinstance(axis, (np.integer, int)):
raise TypeError(f"Rebroadcast needs integer axes. Got {axis}")
if not isinstance(broad, (np.bool_, bool)):
raise TypeError(
f"Rebroadcast needs bool for new broadcast pattern. Got {broad}"
)
def __hash__(self):
# Need special __hash__ as dict aren't hashable.
# no ambiguity because each item key is unique
items = sorted(self.axis.items())
return hash((type(self), tuple(items)))
def __str__(self):
if len(self.axis) == 0:
broadcast_pattern = []
else:
broadcast_pattern = ["?" for i in range(1 + max(self.axis.keys()))]
for k, v in self.axis.items():
broadcast_pattern[k] = str(int(v))
return f"{self.__class__.__name__}{{{','.join(broadcast_pattern)}}}"
def make_node(self, x):
if self.axis.keys() and (x.ndim <= max(self.axis.keys())):
raise ValueError("Trying to rebroadcast non-existent dimension")
t = x.type.clone(
broadcastable=[
self.axis.get(i, b) for i, b in enumerate(x.type.broadcastable)
]
)
return Apply(self, [x], [t()])
def perform(self, node, inp, out_):
(x,) = inp
(out,) = out_
for axis, value in self.axis.items():
if value and x.shape[axis] != 1:
raise ValueError(
f"Dimension {axis} in Rebroadcast's input was"
f" supposed to be 1 (got {x.shape[axis]} instead)"
)
out[0] = x
def grad(self, inp, grads):
(x,) = inp
(gz,) = grads
# restore the broadcasting pattern of the input
return (
Rebroadcast(
*[
(axis, x.type.broadcastable[axis])
for axis, value in self.axis.items()
]
)(gz),
)
def infer_shape(self, fgraph, node, ishapes):
assert len(ishapes) == 1
l = []
one = theano.tensor.basic.constant(1)
for ax in range(len(ishapes[0])):
if self.axis.get(ax, False):
l.append(one)
else:
l.append(ishapes[0][ax])
return [tuple(l)]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self(*eval_points, **dict(return_list=True))
def c_code(self, node, nodename, inp, out, sub):
(iname,) = inp
(oname,) = out
fail = sub["fail"]
itype = node.inputs[0].type.__class__
if itype in self.c_code_and_version:
code, version = self.c_code_and_version[itype]
final_code = ""
for axis, value in self.axis.items():
if value:
final_code += code % locals()
return (
final_code
+ f"""
Py_XDECREF({oname});
{oname} = {iname};
Py_XINCREF({oname});
"""
)
raise NotImplementedError()
def c_code_cache_version(self):
version = []
# If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(
self.c_code_and_version.items(), key=lambda pair: str(pair[0])
):
if not v:
warnings.warn(
f"Type {t} has C code for Rebroadcast, but it "
"has no version. You should add a 'version' "
"keyword arg when calling "
"register_rebroadcast_c_code.",
stacklevel=2,
)
return ()
version.append((str(t), v))
if version:
version.append(1)
return tuple(version)
def register_rebroadcast_c_code(typ, code, version=()):
"""
Tell Rebroadcast how to generate C code for a Theano Type.
typ : Theano type
It must be the Theano class itself and not an instance of the class.
code : C code
That checks if the dimension %(axis)s is of shape 1 for the Theano type
'typ'. Use %(iname)s and %(oname)s for the input and output C variable
names respectively, and %(axis)s for the axis that we need to check.
This code is put in a loop for all axes.
version
A number indicating the version of the code, for cache.
"""
Rebroadcast.c_code_and_version[typ] = (code, version)
register_rebroadcast_c_code(
TensorType,
"""
if(PyArray_DIMS(%(iname)s)[%(axis)s] != 1){
PyErr_Format(PyExc_ValueError,
"Dimension %(axis)s in Rebroadcast's input was"
" supposed to be 1 (got %%d instead)",
PyArray_DIMS(%(iname)s)[%(axis)s]);
%(fail)s
}
""",
version=1,
)
# to be removed as we get the epydoc routine-documenting thing going # to be removed as we get the epydoc routine-documenting thing going
# -JB 20080924 # -JB 20080924
def _conversion(real_value, name): def _conversion(real_value, name):
......
...@@ -762,21 +762,6 @@ theano.compile.register_deep_copy_op_c_code( ...@@ -762,21 +762,6 @@ theano.compile.register_deep_copy_op_c_code(
) )
theano.compile.register_rebroadcast_c_code(
TensorType,
"""
if(PyArray_DIMS(%(iname)s)[%(axis)s] != 1){
PyErr_Format(PyExc_ValueError,
"Dimension %(axis)s in Rebroadcast's input was"
" supposed to be 1 (got %%d instead)",
PyArray_DIMS(%(iname)s)[%(axis)s]);
%(fail)s
}
""",
version=1,
)
def tensor(*args, **kwargs): def tensor(*args, **kwargs):
name = kwargs.pop("name", None) name = kwargs.pop("name", None)
return TensorType(*args, **kwargs)(name=name) return TensorType(*args, **kwargs)(name=name)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论