1 /*
2  * Copyright (C) 2013 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "suggest/core/policy/weighting.h"
18 
19 #include "defines.h"
20 #include "suggest/core/dicnode/dic_node.h"
21 #include "suggest/core/dicnode/dic_node_profiler.h"
22 #include "suggest/core/dicnode/dic_node_utils.h"
23 #include "suggest/core/dictionary/error_type_utils.h"
24 #include "suggest/core/session/dic_traverse_session.h"
25 
26 namespace latinime {
27 
28 class MultiBigramMap;
29 
// Records the kind of correction applied to |node| in the node's profiler
// (node->mProfiler) through the matching PROF_* macro.
// Only compiled in when DEBUG_DICT is enabled; otherwise the body is empty.
static inline void profile(const CorrectionType correctionType, DicNode *const node) {
#if DEBUG_DICT
    switch (correctionType) {
    case CT_OMISSION:
        PROF_OMISSION(node->mProfiler);
        break;
    case CT_ADDITIONAL_PROXIMITY:
        PROF_ADDITIONAL_PROXIMITY(node->mProfiler);
        break;
    case CT_SUBSTITUTION:
        PROF_SUBSTITUTION(node->mProfiler);
        break;
    case CT_NEW_WORD_SPACE_OMISSION:
        PROF_NEW_WORD(node->mProfiler);
        break;
    case CT_MATCH:
        PROF_MATCH(node->mProfiler);
        break;
    case CT_COMPLETION:
        PROF_COMPLETION(node->mProfiler);
        break;
    case CT_TERMINAL:
        PROF_TERMINAL(node->mProfiler);
        break;
    case CT_TERMINAL_INSERTION:
        PROF_TERMINAL_INSERTION(node->mProfiler);
        break;
    case CT_NEW_WORD_SPACE_SUBSTITUTION:
        PROF_SPACE_SUBSTITUTION(node->mProfiler);
        break;
    case CT_INSERTION:
        PROF_INSERTION(node->mProfiler);
        break;
    case CT_TRANSPOSITION:
        PROF_TRANSPOSITION(node->mProfiler);
        break;
    default:
        // Any other correction type is not profiled.
        break;
    }
#endif
}
74 
addCostAndForwardInputIndex(const Weighting * const weighting,const CorrectionType correctionType,const DicTraverseSession * const traverseSession,const DicNode * const parentDicNode,DicNode * const dicNode,MultiBigramMap * const multiBigramMap)75 /* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
76         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
77         const DicNode *const parentDicNode, DicNode *const dicNode,
78         MultiBigramMap *const multiBigramMap) {
79     const int inputSize = traverseSession->getInputSize();
80     DicNode_InputStateG inputStateG;
81     inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default
82     const float spatialCost = Weighting::getSpatialCost(weighting, correctionType,
83             traverseSession, parentDicNode, dicNode, &inputStateG);
84     const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
85             traverseSession, parentDicNode, dicNode, multiBigramMap);
86     const ErrorTypeUtils::ErrorType errorType = weighting->getErrorType(correctionType,
87             traverseSession, parentDicNode, dicNode);
88     profile(correctionType, dicNode);
89     if (inputStateG.mNeedsToUpdateInputStateG) {
90         dicNode->updateInputIndexG(&inputStateG);
91     } else {
92         dicNode->forwardInputIndex(0, getForwardInputCount(correctionType),
93                 (correctionType == CT_TRANSPOSITION));
94     }
95     dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(),
96             inputSize, errorType);
97     if (CT_NEW_WORD_SPACE_OMISSION == correctionType) {
98         // When we are on a terminal, we save the current distance for evaluating
99         // when to auto-commit partial suggestions.
100         dicNode->saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet();
101     }
102 }
103 
getSpatialCost(const Weighting * const weighting,const CorrectionType correctionType,const DicTraverseSession * const traverseSession,const DicNode * const parentDicNode,const DicNode * const dicNode,DicNode_InputStateG * const inputStateG)104 /* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
105         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
106         const DicNode *const parentDicNode, const DicNode *const dicNode,
107         DicNode_InputStateG *const inputStateG) {
108     switch(correctionType) {
109     case CT_OMISSION:
110         return weighting->getOmissionCost(parentDicNode, dicNode);
111     case CT_ADDITIONAL_PROXIMITY:
112         // only used for typing
113         // TODO: Quit calling getMatchedCost().
114         return weighting->getAdditionalProximityCost()
115                 + weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
116     case CT_SUBSTITUTION:
117         // only used for typing
118         // TODO: Quit calling getMatchedCost().
119         return weighting->getSubstitutionCost()
120                 + weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
121     case CT_NEW_WORD_SPACE_OMISSION:
122         return weighting->getSpaceOmissionCost(traverseSession, dicNode, inputStateG);
123     case CT_MATCH:
124         return weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
125     case CT_COMPLETION:
126         return weighting->getCompletionCost(traverseSession, dicNode);
127     case CT_TERMINAL:
128         return weighting->getTerminalSpatialCost(traverseSession, dicNode);
129     case CT_TERMINAL_INSERTION:
130         return weighting->getTerminalInsertionCost(traverseSession, dicNode);
131     case CT_NEW_WORD_SPACE_SUBSTITUTION:
132         return weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
133     case CT_INSERTION:
134         return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
135     case CT_TRANSPOSITION:
136         return weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode);
137     default:
138         return 0.0f;
139     }
140 }
141 
getLanguageCost(const Weighting * const weighting,const CorrectionType correctionType,const DicTraverseSession * const traverseSession,const DicNode * const parentDicNode,const DicNode * const dicNode,MultiBigramMap * const multiBigramMap)142 /* static */ float Weighting::getLanguageCost(const Weighting *const weighting,
143         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
144         const DicNode *const parentDicNode, const DicNode *const dicNode,
145         MultiBigramMap *const multiBigramMap) {
146     switch(correctionType) {
147     case CT_OMISSION:
148         return 0.0f;
149     case CT_SUBSTITUTION:
150         return 0.0f;
151     case CT_NEW_WORD_SPACE_OMISSION:
152         return weighting->getNewWordBigramLanguageCost(
153                 traverseSession, parentDicNode, multiBigramMap);
154     case CT_MATCH:
155         return 0.0f;
156     case CT_COMPLETION:
157         return 0.0f;
158     case CT_TERMINAL: {
159         const float languageImprobability =
160                 DicNodeUtils::getBigramNodeImprobability(
161                         traverseSession->getDictionaryStructurePolicy(), dicNode, multiBigramMap);
162         return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
163     }
164     case CT_TERMINAL_INSERTION:
165         return 0.0f;
166     case CT_NEW_WORD_SPACE_SUBSTITUTION:
167         return weighting->getNewWordBigramLanguageCost(
168                 traverseSession, parentDicNode, multiBigramMap);
169     case CT_INSERTION:
170         return 0.0f;
171     case CT_TRANSPOSITION:
172         return 0.0f;
173     default:
174         return 0.0f;
175     }
176 }
177 
getForwardInputCount(const CorrectionType correctionType)178 /* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
179     switch(correctionType) {
180         case CT_OMISSION:
181             return 0;
182         case CT_ADDITIONAL_PROXIMITY:
183             return 1;
184         case CT_SUBSTITUTION:
185             return 1;
186         case CT_NEW_WORD_SPACE_OMISSION:
187             return 0;
188         case CT_MATCH:
189             return 1;
190         case CT_COMPLETION:
191             return 1;
192         case CT_TERMINAL:
193             return 0;
194         case CT_TERMINAL_INSERTION:
195             return 1;
196         case CT_NEW_WORD_SPACE_SUBSTITUTION:
197             return 1;
198         case CT_INSERTION:
199             return 2; /* look ahead + skip the current char */
200         case CT_TRANSPOSITION:
201             return 2; /* look ahead + skip the current char */
202         default:
203             return 0;
204     }
205 }
206 }  // namespace latinime
207