提交 067a3631 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use type.filter_variable instead of filter_update

上级 f13ab8b7
......@@ -103,7 +103,7 @@ def rebuild_collect_shared( outputs
# Do not use default_update if a "real" update was
# provided
if v not in update_d:
v_update = v.filter_update(v.default_update)
v_update = v.type.filter_variable(v.default_update)
if v_update.type != v.type:
raise TypeError(
( 'an update must have the same type as '
......@@ -188,8 +188,8 @@ def rebuild_collect_shared( outputs
'expression'),
(store_into, update_d[store_into]))
update_val = store_into.filter_update(update_val)
# typically this might be a cast()
# filter_variable ensure smooth conversion of cpu/gpu Types
update_val = store_into.type.filter_variable(update_val)
if update_val.type != store_into.type:
err_msg = ( 'an update must have the same type as the '
'original shared variable(dest, dest.type, '
......
......@@ -152,28 +152,33 @@ class Apply(utils.object2):
:type strict: Bool
:param strict:
If True, the type fields of all the inputs must be equal to the current ones, and
returned outputs are guaranteed to have the same types as self.outputs. If False,
then there's no guarantee that the clone's outputs will have the same types as
self.outputs, and cloning may not even be possible (it depends on the Op).
If True, the type fields of all the inputs must be equal
to the current ones (or compatible, for instance Tensor /
CudaNdarray of the same dtype and broadcastable patterns,
in which case they will be converted into current Type), and
returned outputs are guaranteed to have the same types as
self.outputs. If False, then there's no guarantee that the
clone's outputs will have the same types as self.outputs,
and cloning may not even be possible (it depends on the Op).
:returns: an Apply instance with the same op but different outputs.
"""
remake_node = False
for curr, new in zip(self.inputs, inputs):
new_inputs = inputs[:]
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
if not curr.type == new.type:
if strict:
raise TypeError("Cannot change the type of this input.", ((curr, curr.type),
(new, new.type)))
# If compatible, casts new into curr.type
new_inputs[i] = curr.type.filter_variable(new)
else:
remake_node = True
if remake_node:
new_node = self.op.make_node(*inputs)
new_node = self.op.make_node(*new_inputs)
new_node.tag = copy(self.tag).__update__(new_node.tag)
else:
new_node = self.clone()
new_node.inputs = inputs
new_node.inputs = new_inputs
return new_node
#convenience properties
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论