提交 0380b760 authored 作者: Frederic Bastien's avatar Frederic Bastien

allow to put int and float into input and output pattern of PatternSub

上级 efa90388
...@@ -14,6 +14,7 @@ import op ...@@ -14,6 +14,7 @@ import op
from copy import copy from copy import copy
from theano.gof.python25 import any, all from theano.gof.python25 import any, all
from theano.configparser import AddConfigVar, BoolParam, config from theano.configparser import AddConfigVar, BoolParam, config
import theano
#if sys.version_info[:2] >= (2,5): #if sys.version_info[:2] >= (2,5):
# from collections import defaultdict # from collections import defaultdict
...@@ -517,10 +518,14 @@ class PatternSub(LocalOptimizer): ...@@ -517,10 +518,14 @@ class PatternSub(LocalOptimizer):
sub_pattern ::= input_pattern sub_pattern ::= input_pattern
sub_pattern ::= string sub_pattern ::= string
sub_pattern ::= a Constant instance sub_pattern ::= a Constant instance
sub_pattern ::= int
sub_pattern ::= float
constraint ::= lambda env, expr: additional matching condition constraint ::= lambda env, expr: additional matching condition
output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...) output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= string output_pattern ::= string
output_pattern ::= int
output_pattern ::= float
Each string in the input pattern is a variable that will be set to Each string in the input pattern is a variable that will be set to
whatever expression is found in its place. If the same string is whatever expression is found in its place. If the same string is
...@@ -635,6 +640,11 @@ class PatternSub(LocalOptimizer): ...@@ -635,6 +640,11 @@ class PatternSub(LocalOptimizer):
return retry_with_equiv() return retry_with_equiv()
else: else:
u = u.merge(expr, v) u = u.merge(expr, v)
elif isinstance(pattern, (int, float)) and isinstance(expr, graph.Constant):
if all(theano.tensor.constant(pattern).value==expr.value):
return u
else:
return retry_with_equiv()
elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr): elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr):
return u return u
else: else:
...@@ -647,6 +657,8 @@ class PatternSub(LocalOptimizer): ...@@ -647,6 +657,8 @@ class PatternSub(LocalOptimizer):
return pattern[0](*args) return pattern[0](*args)
elif isinstance(pattern, str): elif isinstance(pattern, str):
return u[unify.Var(pattern)] return u[unify.Var(pattern)]
elif isinstance(pattern, (int,float)):
return pattern
else: else:
return pattern.clone() return pattern.clone()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论