提交 0f2fd4b4 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a clone method to TensorType and its equivalents.

上级 06cc52d7
......@@ -570,7 +570,7 @@ class Rebroadcast(gof.Op):
def __hash__(self):
    """
    Hash consistent with ``__eq__``: combines the Op class with the
    canonical (sorted) sequence of ``axis`` items.
    """
    # Sorting makes the order canonical; no ambiguity because each
    # item key (the axis index) is unique.
    items = sorted(self.axis.iteritems())
    # NOTE(review): the flattened diff carried two return statements;
    # the second was unreachable. Keep the tuple-hash form (the new
    # version from the commit) and drop the dead XOR-based return.
    return hash((type(self), tuple(items)))
def __str__(self):
if len(self.axis) == 0:
......@@ -586,10 +586,9 @@ class Rebroadcast(gof.Op):
def make_node(self, x):
    """
    Build an Apply node rebroadcasting `x`.

    Raises
    ------
    ValueError
        If ``self.axis`` names a dimension that `x` does not have.
    """
    if self.axis.keys() and (x.ndim <= numpy.max(self.axis.keys())):
        raise ValueError('Trying to rebroadcast non-existent dimension')
    # The flattened diff assigned `t` twice (old `x.type.__class__(...)`
    # construction immediately overwritten); only the new `clone`-based
    # construction is kept. Entries of self.axis override the input's
    # broadcastable flags, other dimensions keep theirs.
    t = x.type.clone(broadcastable=[self.axis.get(i, b)
                                    for i, b in enumerate(
                                        x.type.broadcastable)])
    return gof.Apply(self, [x], [t()])
def perform(self, node, inp, out_):
......
......@@ -71,6 +71,11 @@ class CudaNdarrayType(Type):
self.name = name
self.dtype_specs() # error checking is done there
def clone(self, dtype=None, broadcastable=None):
    """
    Return a copy of this type, optionally with a new dtype or a new
    broadcastable pattern.
    """
    if dtype is None:
        # Default to the current dtype, consistent with
        # TensorType.clone and GpuArrayType.clone; the original
        # forwarded None, losing the dtype on a plain clone().
        dtype = self.dtype
    if broadcastable is None:
        broadcastable = self.broadcastable
    return self.__class__(broadcastable, name=self.name, dtype=dtype)
def filter(self, data, strict=False, allow_downcast=None):
    """
    Convert `data` to a value acceptable for this type.

    Simply delegates to ``filter_inplace`` with no pre-allocated
    output buffer (``None``), forwarding the strictness flags.
    """
    return self.filter_inplace(data,
                               None,
                               strict=strict,
                               allow_downcast=allow_downcast)
......
......@@ -607,7 +607,6 @@ class GpuAlloc(HideC, Alloc):
def __init__(self, memset_0=False):
    """
    Parameters
    ----------
    memset_0 : bool
        Optimization flag only. When True, the value being allocated
        is known to always be 0, so the generated C code calls
        ``memset``, which is faster than a generic fill.
    """
    self.memset_0 = memset_0
......
......@@ -28,6 +28,14 @@ class GpuArrayType(Type):
raise TypeError("Unsupported dtype for %s: %s" %
(self.__class__.__name__, self.dtype))
def clone(self, dtype=None, broadcastable=None):
    """
    Return a copy of this type; `dtype` and/or `broadcastable` may be
    overridden, any argument left as None keeps the current value.
    """
    new_dtype = self.dtype if dtype is None else dtype
    new_bcast = self.broadcastable if broadcastable is None else broadcastable
    return self.__class__(dtype=new_dtype, broadcastable=new_bcast,
                          name=self.name)
def __str__(self):
    """Readable form showing the dtype and broadcastable pattern."""
    return "GpuArrayType({0}, {1})".format(self.dtype, self.broadcastable)
......
......@@ -2403,17 +2403,7 @@ class Alloc(gof.Op):
This Op is used to replace fill() during optimizations because after shapes
are lifted, the first argument to fill can often be pruned from the graph.
"""
# With `__props__` defined, theano's gof.Op machinery derives
# __eq__, __hash__ and __str__ from the (empty) property list, so the
# hand-written equivalents that the flattened diff still showed here
# (plus the no-op __init__) were redundant leftovers of the old code
# and are removed; behavior is unchanged.
__props__ = ()
def make_node(self, value, *shape):
v = as_tensor_variable(value)
......
......@@ -52,6 +52,18 @@ class TensorType(Type):
" AdvancedSubtensor1 sparse_grad. Now use"
" theano.sparse_grad(a_tensor[an_int_vector]).")
def clone(self, dtype=None, broadcastable=None):
    """
    Return a copy of the type optionally with a new dtype or
    broadcastable pattern.
    """
    # An argument left as None keeps the corresponding current value.
    new_dtype = self.dtype if dtype is None else dtype
    new_bcast = self.broadcastable if broadcastable is None else broadcastable
    return self.__class__(new_dtype, new_bcast, name=self.name,
                          sparse_grad=self.sparse_grad)
def filter(self, data, strict=False, allow_downcast=None):
"""Convert `data` to something which can be associated to a
`TensorVariable`.
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论