提交 7b25ea4b authored 作者: James Bergstra's avatar James Bergstra

Moved get_constant_value to tensor.basic to use it in Alloc.make_node and/or infer_shape

上级 6fa0a8ae
......@@ -304,6 +304,41 @@ def _allclose(a, b):
rtol = float64_rtol
return numpy.allclose(a,b, atol=atol, rtol=rtol)
def get_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts,
this function digs through them.
If `v` is not some view of constant data, then raise a TypeError.
:note: There may be another function similar to this one in the code, but I'm not sure where it
is.
"""
if isinstance(v, Constant):
#TODO: consider checking for arrays of the form e.g. [1,1,1,1] where
# it is not a constant, but in some cases it *could* be replaced with one.
# Note that this would have an effect on the broadcasting of inputs and so on
try:
complex(v.data) #works for all numeric scalars
return v.data
except:
raise TypeError(v)
if v.owner:
if isinstance(v.owner.op, Alloc):
return get_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, DimShuffle):
return get_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, Rebroadcast):
return get_constant_value(v.owner.inputs[0])
if v.owner.op == fill:
shape, val = v.owner.inputs
# fill(a,b) fills the shape of 'a' filled with 'b'
return get_constant_value(val)
raise TypeError(v)
class TensorType(Type):
"""Symbolic `Type` representing a numpy.ndarray value."""
......@@ -978,8 +1013,11 @@ class _tensor_py_operators:
#TO TRUMP NUMPY OPERATORS
__array_priority__ = 1000
def get_constant_value(self):
return get_constant_value(self)
class TensorVariable(Variable, _tensor_py_operators):
"""Subclass to add the tensor operators to the basic `Variable` class."""
TensorType.Variable = TensorVariable
......@@ -1754,7 +1792,11 @@ class Alloc(gof.Op):
if s.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('Shape arguments must be integers', s)
# if s is constant 1, then we're broadcastable in that dim
bcast.append(isinstance(s, TensorConstant) and (s.data == 1))
try:
const_shp = get_constant_value(s)
except TypeError:
const_shp = None
bcast.append(numpy.all(1 == const_shp))
otype = TensorType(dtype=v.dtype, broadcastable=bcast)
return gof.Apply(self, [v]+sh, [otype()])
......
......@@ -25,6 +25,7 @@ from theano import compile #to register the optimizer built by this file
from theano.gof.python25 import any, all
from theano.gof.opt import Optimizer
from theano.gof import toolbox, DestroyHandler
from basic import get_constant_value
# Utilities
......@@ -63,40 +64,6 @@ def encompasses_broadcastable(b1, b2):
def merge_broadcastables(broadcastables):
return [all(bcast) for bcast in zip(*broadcastables)]
def get_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts,
this function digs through them.
If `v` is not some view of constant data, then raise a TypeError.
:note: There may be another function similar to this one in the code, but I'm not sure where it
is.
"""
if isinstance(v, Constant):
#TODO: consider checking for arrays of the form e.g. [1,1,1,1] where
# it is not a constant, but in some cases it *could* be replaced with one.
# Note that this would have an effect on the broadcasting of inputs and so on
try:
complex(v.data) #works for all numeric scalars
return v.data
except:
raise TypeError(v)
if v.owner:
if isinstance(v.owner.op, T.Alloc):
return get_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, T.DimShuffle):
return get_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, T.Rebroadcast):
return get_constant_value(v.owner.inputs[0])
if v.owner.op == T.fill:
shape, val = v.owner.inputs
# fill(a,b) fills the shape of 'a' filled with 'b'
return get_constant_value(val)
raise TypeError(v)
def scalarconsts_rest(inputs):
"""Partition a list of variables into two kinds:
scalar constants, and the rest."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论