1 /*
2 * Copyright (C) 2013 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "suggest/core/policy/weighting.h"
18
19 #include "defines.h"
20 #include "suggest/core/dicnode/dic_node.h"
21 #include "suggest/core/dicnode/dic_node_profiler.h"
22 #include "suggest/core/dicnode/dic_node_utils.h"
23 #include "suggest/core/dictionary/error_type_utils.h"
24 #include "suggest/core/session/dic_traverse_session.h"
25
26 namespace latinime {
27
28 class MultiBigramMap;
29
// Debug-only hook: applies the PROF_* macro that corresponds to correctionType to the
// node's profiler. When DEBUG_DICT is not set the body is empty and the call does nothing.
static inline void profile(const CorrectionType correctionType, DicNode *const node) {
#if DEBUG_DICT
    switch (correctionType) {
        case CT_OMISSION:
            PROF_OMISSION(node->mProfiler);
            break;
        case CT_ADDITIONAL_PROXIMITY:
            PROF_ADDITIONAL_PROXIMITY(node->mProfiler);
            break;
        case CT_SUBSTITUTION:
            PROF_SUBSTITUTION(node->mProfiler);
            break;
        case CT_NEW_WORD_SPACE_OMISSION:
            PROF_NEW_WORD(node->mProfiler);
            break;
        case CT_MATCH:
            PROF_MATCH(node->mProfiler);
            break;
        case CT_COMPLETION:
            PROF_COMPLETION(node->mProfiler);
            break;
        case CT_TERMINAL:
            PROF_TERMINAL(node->mProfiler);
            break;
        case CT_TERMINAL_INSERTION:
            PROF_TERMINAL_INSERTION(node->mProfiler);
            break;
        case CT_NEW_WORD_SPACE_SUBSTITUTION:
            PROF_SPACE_SUBSTITUTION(node->mProfiler);
            break;
        case CT_INSERTION:
            PROF_INSERTION(node->mProfiler);
            break;
        case CT_TRANSPOSITION:
            PROF_TRANSPOSITION(node->mProfiler);
            break;
        default:
            // Correction types without a dedicated profiling macro are ignored.
            break;
    }
#endif  // DEBUG_DICT
}
74
addCostAndForwardInputIndex(const Weighting * const weighting,const CorrectionType correctionType,const DicTraverseSession * const traverseSession,const DicNode * const parentDicNode,DicNode * const dicNode,MultiBigramMap * const multiBigramMap)75 /* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
76 const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
77 const DicNode *const parentDicNode, DicNode *const dicNode,
78 MultiBigramMap *const multiBigramMap) {
79 const int inputSize = traverseSession->getInputSize();
80 DicNode_InputStateG inputStateG;
81 inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default
82 const float spatialCost = Weighting::getSpatialCost(weighting, correctionType,
83 traverseSession, parentDicNode, dicNode, &inputStateG);
84 const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
85 traverseSession, parentDicNode, dicNode, multiBigramMap);
86 const ErrorTypeUtils::ErrorType errorType = weighting->getErrorType(correctionType,
87 traverseSession, parentDicNode, dicNode);
88 profile(correctionType, dicNode);
89 if (inputStateG.mNeedsToUpdateInputStateG) {
90 dicNode->updateInputIndexG(&inputStateG);
91 } else {
92 dicNode->forwardInputIndex(0, getForwardInputCount(correctionType),
93 (correctionType == CT_TRANSPOSITION));
94 }
95 dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(),
96 inputSize, errorType);
97 if (CT_NEW_WORD_SPACE_OMISSION == correctionType) {
98 // When we are on a terminal, we save the current distance for evaluating
99 // when to auto-commit partial suggestions.
100 dicNode->saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet();
101 }
102 }
103
getSpatialCost(const Weighting * const weighting,const CorrectionType correctionType,const DicTraverseSession * const traverseSession,const DicNode * const parentDicNode,const DicNode * const dicNode,DicNode_InputStateG * const inputStateG)104 /* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
105 const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
106 const DicNode *const parentDicNode, const DicNode *const dicNode,
107 DicNode_InputStateG *const inputStateG) {
108 switch(correctionType) {
109 case CT_OMISSION:
110 return weighting->getOmissionCost(parentDicNode, dicNode);
111 case CT_ADDITIONAL_PROXIMITY:
112 // only used for typing
113 // TODO: Quit calling getMatchedCost().
114 return weighting->getAdditionalProximityCost()
115 + weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
116 case CT_SUBSTITUTION:
117 // only used for typing
118 // TODO: Quit calling getMatchedCost().
119 return weighting->getSubstitutionCost()
120 + weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
121 case CT_NEW_WORD_SPACE_OMISSION:
122 return weighting->getSpaceOmissionCost(traverseSession, dicNode, inputStateG);
123 case CT_MATCH:
124 return weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
125 case CT_COMPLETION:
126 return weighting->getCompletionCost(traverseSession, dicNode);
127 case CT_TERMINAL:
128 return weighting->getTerminalSpatialCost(traverseSession, dicNode);
129 case CT_TERMINAL_INSERTION:
130 return weighting->getTerminalInsertionCost(traverseSession, dicNode);
131 case CT_NEW_WORD_SPACE_SUBSTITUTION:
132 return weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
133 case CT_INSERTION:
134 return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
135 case CT_TRANSPOSITION:
136 return weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode);
137 default:
138 return 0.0f;
139 }
140 }
141
getLanguageCost(const Weighting * const weighting,const CorrectionType correctionType,const DicTraverseSession * const traverseSession,const DicNode * const parentDicNode,const DicNode * const dicNode,MultiBigramMap * const multiBigramMap)142 /* static */ float Weighting::getLanguageCost(const Weighting *const weighting,
143 const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
144 const DicNode *const parentDicNode, const DicNode *const dicNode,
145 MultiBigramMap *const multiBigramMap) {
146 switch(correctionType) {
147 case CT_OMISSION:
148 return 0.0f;
149 case CT_SUBSTITUTION:
150 return 0.0f;
151 case CT_NEW_WORD_SPACE_OMISSION:
152 return weighting->getNewWordBigramLanguageCost(
153 traverseSession, parentDicNode, multiBigramMap);
154 case CT_MATCH:
155 return 0.0f;
156 case CT_COMPLETION:
157 return 0.0f;
158 case CT_TERMINAL: {
159 const float languageImprobability =
160 DicNodeUtils::getBigramNodeImprobability(
161 traverseSession->getDictionaryStructurePolicy(), dicNode, multiBigramMap);
162 return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
163 }
164 case CT_TERMINAL_INSERTION:
165 return 0.0f;
166 case CT_NEW_WORD_SPACE_SUBSTITUTION:
167 return weighting->getNewWordBigramLanguageCost(
168 traverseSession, parentDicNode, multiBigramMap);
169 case CT_INSERTION:
170 return 0.0f;
171 case CT_TRANSPOSITION:
172 return 0.0f;
173 default:
174 return 0.0f;
175 }
176 }
177
getForwardInputCount(const CorrectionType correctionType)178 /* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
179 switch(correctionType) {
180 case CT_OMISSION:
181 return 0;
182 case CT_ADDITIONAL_PROXIMITY:
183 return 1;
184 case CT_SUBSTITUTION:
185 return 1;
186 case CT_NEW_WORD_SPACE_OMISSION:
187 return 0;
188 case CT_MATCH:
189 return 1;
190 case CT_COMPLETION:
191 return 1;
192 case CT_TERMINAL:
193 return 0;
194 case CT_TERMINAL_INSERTION:
195 return 1;
196 case CT_NEW_WORD_SPACE_SUBSTITUTION:
197 return 1;
198 case CT_INSERTION:
199 return 2; /* look ahead + skip the current char */
200 case CT_TRANSPOSITION:
201 return 2; /* look ahead + skip the current char */
202 default:
203 return 0;
204 }
205 }
206 } // namespace latinime
207