Unverified 提交 98875d1e authored 作者: Thomas Wiecki's avatar Thomas Wiecki 提交者: GitHub

Merge pull request #21 from brandonwillard/jax-linker

Introduce a JAX Linker class
...@@ -60,6 +60,7 @@ install: ...@@ -60,6 +60,7 @@ install:
- conda create --yes -q -n pyenv python=$TRAVIS_PYTHON_VERSION - conda create --yes -q -n pyenv python=$TRAVIS_PYTHON_VERSION
- conda activate pyenv - conda activate pyenv
- conda install --yes -q mkl numpy scipy pip mkl-service graphviz cython libgpuarray pygpu - conda install --yes -q mkl numpy scipy pip mkl-service graphviz cython libgpuarray pygpu
- if [[ "$TRAVIS_PYTHON_VERSION" != "3.6" ]]; then conda install --yes -q -c conda-forge jax jaxlib; fi
- pip install -q -r requirements.txt - pip install -q -r requirements.txt
- conda list && pip freeze - conda list && pip freeze
- python -c 'import theano; print(theano.config.__str__(print_doc=False))' - python -c 'import theano; print(theano.config.__str__(print_doc=False))'
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
flake8 flake8
pep8 pep8
pyflakes pyflakes
black==20.8b1 black==20.8b1; platform_python_implementation != 'PyPy'
pytest-cov>=2.6.1 pytest-cov>=2.6.1
coverage>=5.1 coverage>=5.1
pytest pytest
...@@ -10,3 +10,5 @@ coveralls ...@@ -10,3 +10,5 @@ coveralls
cython cython
sympy sympy
versioneer versioneer
jax; python_version > '3.6'
jaxlib; python_version > '3.6'
差异被折叠。
...@@ -7,12 +7,13 @@ import logging ...@@ -7,12 +7,13 @@ import logging
import warnings import warnings
import theano import theano
from theano import gof
import theano.gof.vm import theano.gof.vm
from theano import config
from six import string_types from six import string_types
from theano.compile.function_module import Supervisor
from theano import config, gof
from theano.compile.function_module import Supervisor
from theano.sandbox.jax_linker import JAXLinker
_logger = logging.getLogger("theano.compile.mode") _logger = logging.getLogger("theano.compile.mode")
...@@ -29,6 +30,7 @@ predefined_linkers = { ...@@ -29,6 +30,7 @@ predefined_linkers = {
"cvm": gof.vm.VM_Linker(use_cloop=True), # Use allow_gc Theano flag "cvm": gof.vm.VM_Linker(use_cloop=True), # Use allow_gc Theano flag
"vm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=False), "vm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=False),
"cvm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=True), "cvm_nogc": gof.vm.VM_Linker(allow_gc=False, use_cloop=True),
"jax": JAXLinker(),
} }
...@@ -411,9 +413,15 @@ if theano.config.cxx: ...@@ -411,9 +413,15 @@ if theano.config.cxx:
else: else:
FAST_RUN = Mode("vm", "fast_run") FAST_RUN = Mode("vm", "fast_run")
JAX = Mode(
JAXLinker(), gof.Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
)
predefined_modes = { predefined_modes = {
"FAST_COMPILE": FAST_COMPILE, "FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN, "FAST_RUN": FAST_RUN,
"JAX": JAX,
} }
instantiated_default_mode = None instantiated_default_mode = None
......
...@@ -596,14 +596,16 @@ AddConfigVar( ...@@ -596,14 +596,16 @@ AddConfigVar(
# Also, please be careful not to modify the first item in the enum when adding # Also, please be careful not to modify the first item in the enum when adding
# new modes, since it is the default mode. # new modes, since it is the default mode.
def filter_mode(val): def filter_mode(val):
if val in [ if (
val
in [
"Mode", "Mode",
"DebugMode", "DebugMode",
"FAST_RUN",
"NanGuardMode", "NanGuardMode",
"FAST_COMPILE",
"DEBUG_MODE", "DEBUG_MODE",
]: ]
or val in theano.compile.mode.predefined_modes
):
return val return val
# This can be executed before Theano is completly imported, so # This can be executed before Theano is completly imported, so
# theano.Mode is not always available. # theano.Mode is not always available.
......
from warnings import warn
from collections.abc import Sequence
from theano.gof.link import (
PerformLinker,
map_storage,
gc_helper,
utils,
add_clear_storage,
Container,
streamline,
)
from theano.gof.graph import Constant
class JAXLinker(PerformLinker):
    """A `Linker` that JIT-compiles NumPy-based operations using JAX.

    Attributes
    ----------
    allow_non_jax: bool
        A boolean indicating whether or not an exception is thrown when the
        graph cannot be JAX compiled (e.g. the graph has an unsupported operator).
        If `allow_non_jax` is `True`, the fallback is currently Python compilation.

    """

    allow_non_jax = False

    def create_jax_thunks(self, compute_map, storage_map):
        """Create a thunk for each output of the `Linker`s `FunctionGraph`.

        This differs from the other thunk-making functions in that it only
        produces thunks for the `FunctionGraph` output nodes; intermediate
        nodes are computed inside the JIT-compiled JAX functions.

        Parameters
        ----------
        compute_map: dict
            The compute map dictionary.
        storage_map: dict
            The storage map dictionary.

        Returns
        -------
        thunks: list
            A list containing the thunks.
        output_nodes: list
            A list containing the output nodes.

        """
        import jax

        from theano.sandbox.jaxify import jax_funcify

        output_nodes = [o.owner for o in self.fgraph.outputs]

        # Create a JAX-compilable function from our `FunctionGraph`
        jaxed_fgraph_outputs = jax_funcify(self.fgraph)

        assert len(jaxed_fgraph_outputs) == len(output_nodes)

        # I suppose we can consider `Constant`s to be "static" according to
        # JAX.
        static_argnums = [
            n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
        ]

        thunk_inputs = [storage_map[n] for n in self.fgraph.inputs]

        thunks = []

        for node, jax_funcs in zip(output_nodes, jaxed_fgraph_outputs):

            thunk_outputs = [storage_map[n] for n in node.outputs]

            # JIT-compile the functions.  `static_argnums` is passed by
            # keyword because newer JAX releases make `jax.jit`'s
            # configuration parameters keyword-only.
            if len(node.outputs) > 1:
                # FIX: was `node.ouptputs` (typo), which raised an
                # `AttributeError` whenever a multi-output node reached
                # this branch.
                assert len(jax_funcs) == len(node.outputs)
                jax_impl_jits = [
                    jax.jit(jax_func, static_argnums=static_argnums)
                    for jax_func in jax_funcs
                ]
            else:
                assert not isinstance(jax_funcs, Sequence)
                jax_impl_jits = [jax.jit(jax_funcs, static_argnums=static_argnums)]

            def thunk(
                node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs
            ):
                # Unpack the single-element storage cells into positional
                # arguments and run every JIT-compiled function for this node.
                outputs = [
                    jax_impl_jit(*[x[0] for x in thunk_inputs])
                    for jax_impl_jit in jax_impl_jits
                ]

                # Write results back into the storage map and mark the
                # outputs as computed.
                for o_node, o_storage, o_val in zip(
                    node.outputs, thunk_outputs, outputs
                ):
                    compute_map[o_node][0] = True
                    if len(o_storage) > 1:
                        assert len(o_storage) == len(o_val)
                        for i, o_sub_val in enumerate(o_val):
                            o_storage[i] = o_sub_val
                    else:
                        o_storage[0] = o_val
                return outputs

            thunk.inputs = thunk_inputs
            thunk.outputs = thunk_outputs
            thunk.lazy = False

            thunks.append(thunk)

        return thunks, output_nodes

    def make_all(self, input_storage=None, output_storage=None, storage_map=None):
        """Build the function, containers, thunks, and nodes for this linker.

        Falls back to node-by-node Python thunks when the graph cannot be
        JAX-compiled and `allow_non_jax` is `True`.
        """
        fgraph = self.fgraph
        nodes = self.schedule(fgraph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            fgraph, nodes, input_storage, output_storage, storage_map
        )

        compute_map = {}
        for k in storage_map:
            compute_map[k] = [k.owner is None]

        try:
            # We need to create thunk functions that will populate the output
            # storage arrays with the JAX-computed values.
            thunks, nodes = self.create_jax_thunks(compute_map, storage_map)
        except NotImplementedError as e:
            if not self.allow_non_jax:
                raise

            warn("JaxLinker could not JAXify graph: {}".format(e))

            # Fallback: ask each node's op for a Python ("py") thunk.
            thunks = []
            for node in nodes:
                thunk = node.op.make_thunk(
                    node, storage_map, compute_map, no_recycling, "py"
                )
                thunk_inputs = [storage_map[v] for v in node.inputs]
                thunk_outputs = [storage_map[v] for v in node.outputs]

                thunk.inputs = thunk_inputs
                thunk.outputs = thunk_outputs

                thunks.append(thunk)

        computed, last_user = gc_helper(nodes)

        if self.allow_gc:
            # For each node, collect the storage of inputs that can be
            # garbage-collected after that node's last use.
            post_thunk_old_storage = []

            for node in nodes:
                post_thunk_old_storage.append(
                    [
                        storage_map[input]
                        for input in node.inputs
                        if (input in computed)
                        and (input not in fgraph.outputs)
                        and (node == last_user[input])
                    ]
                )
        else:
            post_thunk_old_storage = None

        if no_recycling is True:
            no_recycling = list(storage_map.values())
            no_recycling = utils.difference(no_recycling, input_storage)
        else:
            no_recycling = [
                storage_map[r] for r in no_recycling if r not in fgraph.inputs
            ]

        fn = streamline(
            fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
        )

        fn.allow_gc = self.allow_gc
        add_clear_storage(fn, computed, storage_map)
        fn.storage_map = storage_map

        return (
            fn,
            [
                Container(input, storage)
                for input, storage in zip(fgraph.inputs, input_storage)
            ],
            [
                Container(output, storage, True)
                for output, storage in zip(fgraph.outputs, output_storage)
            ],
            thunks,
            nodes,
        )
差异被折叠。
...@@ -1767,6 +1767,7 @@ class Maximum(BinaryScalarOp): ...@@ -1767,6 +1767,7 @@ class Maximum(BinaryScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ("maximum", 2, 1) nfunc_spec = ("maximum", 2, 1)
nfunc_variadic = "maximum"
def impl(self, *inputs): def impl(self, *inputs):
# The built-in max function don't support complex type # The built-in max function don't support complex type
...@@ -1811,6 +1812,7 @@ class Minimum(BinaryScalarOp): ...@@ -1811,6 +1812,7 @@ class Minimum(BinaryScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ("minimum", 2, 1) nfunc_spec = ("minimum", 2, 1)
nfunc_variadic = "minimum"
def impl(self, *inputs): def impl(self, *inputs):
# The built-in min function don't support complex type # The built-in min function don't support complex type
...@@ -1855,6 +1857,7 @@ class Add(ScalarOp): ...@@ -1855,6 +1857,7 @@ class Add(ScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ("add", 2, 1) nfunc_spec = ("add", 2, 1)
nfunc_variadic = "sum"
def impl(self, *inputs): def impl(self, *inputs):
return sum(inputs) return sum(inputs)
...@@ -1896,6 +1899,7 @@ class Mul(ScalarOp): ...@@ -1896,6 +1899,7 @@ class Mul(ScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ("multiply", 2, 1) nfunc_spec = ("multiply", 2, 1)
nfunc_variadic = "product"
def impl(self, *inputs): def impl(self, *inputs):
return np.product(inputs) return np.product(inputs)
...@@ -2984,6 +2988,8 @@ class Inv(UnaryScalarOp): ...@@ -2984,6 +2988,8 @@ class Inv(UnaryScalarOp):
""" """
nfunc_spec = ("reciprocal", 1, 1)
def impl(self, x): def impl(self, x):
return np.float32(1.0) / x return np.float32(1.0) / x
......
...@@ -1787,6 +1787,20 @@ def max_and_argmax(a, axis=None, keepdims=False): ...@@ -1787,6 +1787,20 @@ def max_and_argmax(a, axis=None, keepdims=False):
return [out, argout] return [out, argout]
class Max(CAReduce):
nfunc_spec = ("max", 1, 1)
def __init__(self, axis):
super().__init__(scal.maximum, axis)
class Min(CAReduce):
nfunc_spec = ("min", 1, 1)
def __init__(self, axis):
super().__init__(scal.minimum, axis)
@constructor @constructor
def max(x, axis=None, keepdims=False): def max(x, axis=None, keepdims=False):
""" """
...@@ -1823,7 +1837,7 @@ def max(x, axis=None, keepdims=False): ...@@ -1823,7 +1837,7 @@ def max(x, axis=None, keepdims=False):
try: try:
out = max_and_argmax(x, axis)[0] out = max_and_argmax(x, axis)[0]
except Exception: except Exception:
out = CAReduce(scal.maximum, axis)(x) out = Max(axis)(x)
if keepdims: if keepdims:
out = makeKeepDims(x, out, axis) out = makeKeepDims(x, out, axis)
...@@ -3416,7 +3430,7 @@ def prod( ...@@ -3416,7 +3430,7 @@ def prod(
class Mean(elemwise.CAReduce): class Mean(elemwise.CAReduce):
def __init__(self, axis=None): def __init__(self, axis=None):
elemwise.CAReduce.__init__(self, scal.add, axis) super().__init__(scal.add, axis)
assert self.axis is None or len(self.axis) == 1 assert self.axis is None or len(self.axis) == 1
def __str__(self): def __str__(self):
...@@ -3443,7 +3457,7 @@ class Mean(elemwise.CAReduce): ...@@ -3443,7 +3457,7 @@ class Mean(elemwise.CAReduce):
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
if self.axis is not None: if self.axis is not None:
return super(Op, self).c_code(node, name, inames, onames, sub) return super(Op, self).c_code(node, name, inames, onames, sub)
ret = elemwise.CAReduce.c_code(self, node, name, inames, onames, sub) ret = super().c_code(node, name, inames, onames, sub)
# TODO: c_code perform support only axis is None # TODO: c_code perform support only axis is None
return ( return (
ret ret
......
...@@ -1761,6 +1761,7 @@ class All(CAReduce): ...@@ -1761,6 +1761,7 @@ class All(CAReduce):
""" """
__props__ = ("axis",) __props__ = ("axis",)
nfunc_spec = ("all", 1, 1)
def __init__(self, axis=None): def __init__(self, axis=None):
CAReduce.__init__(self, scalar.and_, axis) CAReduce.__init__(self, scalar.and_, axis)
...@@ -1793,6 +1794,7 @@ class Any(CAReduce): ...@@ -1793,6 +1794,7 @@ class Any(CAReduce):
""" """
__props__ = ("axis",) __props__ = ("axis",)
nfunc_spec = ("any", 1, 1)
def __init__(self, axis=None): def __init__(self, axis=None):
CAReduce.__init__(self, scalar.or_, axis) CAReduce.__init__(self, scalar.or_, axis)
...@@ -2027,6 +2029,7 @@ class Sum(CAReduceDtype): ...@@ -2027,6 +2029,7 @@ class Sum(CAReduceDtype):
""" """
__props__ = ("axis", "dtype", "acc_dtype") __props__ = ("axis", "dtype", "acc_dtype")
nfunc_spec = ("sum", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None): def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__( CAReduceDtype.__init__(
...@@ -2085,6 +2088,7 @@ class Prod(CAReduceDtype): ...@@ -2085,6 +2088,7 @@ class Prod(CAReduceDtype):
""" """
__props__ = ("axis", "dtype", "acc_dtype") __props__ = ("axis", "dtype", "acc_dtype")
nfunc_spec = ("prod", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False): def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False):
CAReduceDtype.__init__( CAReduceDtype.__init__(
......
...@@ -31,6 +31,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp): ...@@ -31,6 +31,8 @@ class ScalarSigmoid(scalar.UnaryScalarOp):
""" """
nfunc_spec = ("scipy.special.expit", 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
if x < -30.0: if x < -30.0:
...@@ -196,6 +198,8 @@ class UltraFastScalarSigmoid(scalar.UnaryScalarOp): ...@@ -196,6 +198,8 @@ class UltraFastScalarSigmoid(scalar.UnaryScalarOp):
""" """
nfunc_spec = ("scipy.special.expit", 1, 1)
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
x = 0.5 * x x = 0.5 * x
......
...@@ -31,44 +31,40 @@ supposed to be canonical. ...@@ -31,44 +31,40 @@ supposed to be canonical.
""" """
# TODO: intelligent merge for mul/add
# TODO: 0*x -> 0
import logging import logging
from theano import gof import theano.tensor.basic as tt
from theano.tensor.elemwise import CAReduce import theano.scalar.basic as scal
from theano.tensor import basic as T
from theano.tensor import DimShuffle, Subtensor
from theano.gof.opt import copy_stack_trace, local_optimizer
from theano.tensor.subtensor import Subtensor
from theano.tensor.elemwise import CAReduce, DimShuffle
from theano.tensor.opt import register_uncanonicalize from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal
from theano.gof.opt import copy_stack_trace
_logger = logging.getLogger("theano.tensor.opt") _logger = logging.getLogger("theano.tensor.opt")
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.MaxAndArgmax]) @local_optimizer([tt.MaxAndArgmax])
def local_max_and_argmax(node): def local_max_and_argmax(node):
""" """
If we don't use the argmax, change it to a max only. If we don't use the argmax, change it to a max only.
""" """
if isinstance(node.op, T.MaxAndArgmax): if isinstance(node.op, tt.MaxAndArgmax):
axis = node.op.get_params(node) axis = node.op.get_params(node)
if len(node.outputs[1].clients) == 0: if len(node.outputs[1].clients) == 0:
new = CAReduce(scal.maximum, axis)(node.inputs[0]) new = tt.Max(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new) copy_stack_trace(node.outputs[0], new)
return [new, None] return [new, None]
if len(node.outputs[0].clients) == 0: if len(node.outputs[0].clients) == 0:
new = T.Argmax(axis)(node.inputs[0]) new = tt.Argmax(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new) copy_stack_trace(node.outputs[0], new)
return [None, new] return [None, new]
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.neg]) @local_optimizer([tt.neg])
def local_max_to_min(node): def local_max_to_min(node):
""" """
Change -(max(-x)) to min. Change -(max(-x)) to min.
...@@ -81,7 +77,7 @@ def local_max_to_min(node): ...@@ -81,7 +77,7 @@ def local_max_to_min(node):
the interface put only MaxAndArgmax into the graph. the interface put only MaxAndArgmax into the graph.
""" """
if node.op == T.neg and node.inputs[0].owner: if node.op == tt.neg and node.inputs[0].owner:
max = node.inputs[0] max = node.inputs[0]
if ( if (
max.owner max.owner
...@@ -89,15 +85,15 @@ def local_max_to_min(node): ...@@ -89,15 +85,15 @@ def local_max_to_min(node):
and max.owner.op.scalar_op == scal.maximum and max.owner.op.scalar_op == scal.maximum
): ):
neg = max.owner.inputs[0] neg = max.owner.inputs[0]
if neg.owner and neg.owner.op == T.neg: if neg.owner and neg.owner.op == tt.neg:
new = CAReduce(scal.minimum, max.owner.op.axis)(neg.owner.inputs[0]) new = tt.Min(max.owner.op.axis)(neg.owner.inputs[0])
return [copy_stack_trace(node.outputs[0], new)] return [copy_stack_trace(node.outputs[0], new)]
return False return False
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.Alloc]) @local_optimizer([tt.Alloc])
def local_alloc_dimshuffle(node): def local_alloc_dimshuffle(node):
""" """
If a dimshuffle is inside an alloc and only adds dimension to the If a dimshuffle is inside an alloc and only adds dimension to the
...@@ -105,7 +101,7 @@ def local_alloc_dimshuffle(node): ...@@ -105,7 +101,7 @@ def local_alloc_dimshuffle(node):
Alloc(DimShuffle(x), ...) - > Alloc(x, ...) Alloc(DimShuffle(x), ...) - > Alloc(x, ...)
""" """
if isinstance(node.op, T.Alloc): if isinstance(node.op, tt.Alloc):
input_ = node.inputs[0] input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle): if input_.owner and isinstance(input_.owner.op, DimShuffle):
# check if it only adds dimension to the left # check if it only adds dimension to the left
...@@ -115,12 +111,12 @@ def local_alloc_dimshuffle(node): ...@@ -115,12 +111,12 @@ def local_alloc_dimshuffle(node):
) + tuple(range(input_.owner.inputs[0].ndim)) ) + tuple(range(input_.owner.inputs[0].ndim))
if new_order != expected_new_order: if new_order != expected_new_order:
return False return False
return [T.alloc(input_.owner.inputs[0], *node.inputs[1:])] return [tt.alloc(input_.owner.inputs[0], *node.inputs[1:])]
return False return False
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T.Reshape]) @local_optimizer([tt.Reshape])
def local_reshape_dimshuffle(node): def local_reshape_dimshuffle(node):
""" """
If a dimshuffle is inside a reshape and does not change the order If a dimshuffle is inside a reshape and does not change the order
...@@ -128,7 +124,7 @@ def local_reshape_dimshuffle(node): ...@@ -128,7 +124,7 @@ def local_reshape_dimshuffle(node):
Reshape(Dimshuffle(x), shp) -> Reshape(x, shp) Reshape(Dimshuffle(x), shp) -> Reshape(x, shp)
""" """
if isinstance(node.op, T.Reshape): if isinstance(node.op, tt.Reshape):
input_ = node.inputs[0] input_ = node.inputs[0]
if input_.owner and isinstance(input_.owner.op, DimShuffle): if input_.owner and isinstance(input_.owner.op, DimShuffle):
new_order = input_.owner.op.new_order new_order = input_.owner.op.new_order
...@@ -141,7 +137,7 @@ def local_reshape_dimshuffle(node): ...@@ -141,7 +137,7 @@ def local_reshape_dimshuffle(node):
else: else:
offset += 1 offset += 1
return [ return [
T.reshape( tt.reshape(
input_.owner.inputs[0], node.inputs[1], ndim=node.outputs[0].ndim input_.owner.inputs[0], node.inputs[1], ndim=node.outputs[0].ndim
) )
] ]
...@@ -149,7 +145,7 @@ def local_reshape_dimshuffle(node): ...@@ -149,7 +145,7 @@ def local_reshape_dimshuffle(node):
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([DimShuffle]) @local_optimizer([DimShuffle])
def local_dimshuffle_alloc(node): def local_dimshuffle_alloc(node):
""" """
If an alloc is inside a dimshuffle which only adds dimension to the left, If an alloc is inside a dimshuffle which only adds dimension to the left,
...@@ -159,7 +155,7 @@ def local_dimshuffle_alloc(node): ...@@ -159,7 +155,7 @@ def local_dimshuffle_alloc(node):
""" """
if isinstance(node.op, DimShuffle) and node.inputs[0].owner: if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
input_ = node.inputs[0] input_ = node.inputs[0]
if isinstance(input_.owner.op, T.Alloc): if isinstance(input_.owner.op, tt.Alloc):
# check if it only adds dimension to the left # check if it only adds dimension to the left
new_order = node.op.new_order new_order = node.op.new_order
expected_new_order = ("x",) * (len(new_order) - input_.ndim) + tuple( expected_new_order = ("x",) * (len(new_order) - input_.ndim) + tuple(
...@@ -172,12 +168,12 @@ def local_dimshuffle_alloc(node): ...@@ -172,12 +168,12 @@ def local_dimshuffle_alloc(node):
nb_new_dims = len(new_order) - input_.ndim nb_new_dims = len(new_order) - input_.ndim
new_shape_input = (1,) * nb_new_dims + tuple(input_.owner.inputs[1:]) new_shape_input = (1,) * nb_new_dims + tuple(input_.owner.inputs[1:])
return [T.alloc(input_.owner.inputs[0], *new_shape_input)] return [tt.alloc(input_.owner.inputs[0], *new_shape_input)]
return False return False
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([DimShuffle]) @local_optimizer([DimShuffle])
def local_dimshuffle_subtensor(node): def local_dimshuffle_subtensor(node):
"""If a subtensor is inside a dimshuffle which only drop """If a subtensor is inside a dimshuffle which only drop
broadcastable dimensions, scrap the dimshuffle and index the broadcastable dimensions, scrap the dimshuffle and index the
...@@ -223,7 +219,7 @@ def local_dimshuffle_subtensor(node): ...@@ -223,7 +219,7 @@ def local_dimshuffle_subtensor(node):
# tensor was indexed such as x[scalar, :, :], check that as well # tensor was indexed such as x[scalar, :, :], check that as well
new_idx_list = list(input_.owner.op.idx_list) new_idx_list = list(input_.owner.op.idx_list)
new_inputs = [input_.owner.inputs[0]] new_inputs = [input_.owner.inputs[0]]
zero = T.constant(0) zero = tt.constant(0)
slice_attr_list = ["start", "stop", "step"] slice_attr_list = ["start", "stop", "step"]
j = 0 j = 0
slice_i = -1 slice_i = -1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论