/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef LATINIME_TYPING_WEIGHTING_H
18 #define LATINIME_TYPING_WEIGHTING_H
19 
20 #include "defines.h"
21 #include "suggest/core/dicnode/dic_node_utils.h"
22 #include "suggest/core/dictionary/error_type_utils.h"
23 #include "suggest/core/layout/touch_position_correction_utils.h"
24 #include "suggest/core/policy/weighting.h"
25 #include "suggest/core/session/dic_traverse_session.h"
26 #include "suggest/policyimpl/typing/scoring_params.h"
27 #include "utils/char_utils.h"
28 
29 namespace latinime {
30 
31 class DicNode;
32 struct DicNode_InputStateG;
33 class MultiBigramMap;
34 
35 class TypingWeighting : public Weighting {
36  public:
getInstance()37     static const TypingWeighting *getInstance() { return &sInstance; }
38 
39  protected:
getTerminalSpatialCost(const DicTraverseSession * const traverseSession,const DicNode * const dicNode)40     float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
41             const DicNode *const dicNode) const {
42         float cost = 0.0f;
43         if (dicNode->hasMultipleWords()) {
44             cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST;
45         }
46         if (dicNode->getProximityCorrectionCount() > 0) {
47             cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST;
48         }
49         if (dicNode->getEditCorrectionCount() > 0) {
50             cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST;
51         }
52         return cost;
53     }
54 
getOmissionCost(const DicNode * const parentDicNode,const DicNode * const dicNode)55     float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const {
56         const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
57         const bool isIntentionalOmission = parentDicNode->canBeIntentionalOmission();
58         const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
59         // If the traversal omitted the first letter then the dicNode should now be on the second.
60         const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2;
61         float cost = 0.0f;
62         if (isZeroCostOmission) {
63             cost = 0.0f;
64         } else if (isIntentionalOmission) {
65             cost = ScoringParams::INTENTIONAL_OMISSION_COST;
66         } else if (isFirstLetterOmission) {
67             cost = ScoringParams::OMISSION_COST_FIRST_CHAR;
68         } else {
69             cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR
70                     : ScoringParams::OMISSION_COST;
71         }
72         return cost;
73     }
74 
getMatchedCost(const DicTraverseSession * const traverseSession,const DicNode * const dicNode,DicNode_InputStateG * inputStateG)75     float getMatchedCost(const DicTraverseSession *const traverseSession,
76             const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
77         const int pointIndex = dicNode->getInputIndex(0);
78         const float normalizedSquaredLength = traverseSession->getProximityInfoState(0)
79                 ->getPointToKeyLength(pointIndex,
80                         CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
81         const float normalizedDistance = TouchPositionCorrectionUtils::getSweetSpotFactor(
82                 traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength);
83         const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance;
84 
85         const bool isFirstChar = pointIndex == 0;
86         const bool isProximity = isProximityDicNode(traverseSession, dicNode);
87         float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROXIMITY_COST
88                 : ScoringParams::PROXIMITY_COST) : 0.0f;
89         if (isProximity && dicNode->getProximityCorrectionCount() == 0) {
90             cost += ScoringParams::FIRST_PROXIMITY_COST;
91         }
92         if (dicNode->getNodeCodePointCount() == 2) {
93             // At the second character of the current word, we check if the first char is uppercase
94             // and the word is a second or later word of a multiple word suggestion. We demote it
95             // if so.
96             const bool isSecondOrLaterWordFirstCharUppercase =
97                     dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase();
98             if (isSecondOrLaterWordFirstCharUppercase) {
99                 cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE;
100             }
101         }
102         return weightedDistance + cost;
103     }
104 
isProximityDicNode(const DicTraverseSession * const traverseSession,const DicNode * const dicNode)105     bool isProximityDicNode(const DicTraverseSession *const traverseSession,
106             const DicNode *const dicNode) const {
107         const int pointIndex = dicNode->getInputIndex(0);
108         const int primaryCodePoint = CharUtils::toBaseLowerCase(
109                 traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex));
110         const int dicNodeChar = CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint());
111         return primaryCodePoint != dicNodeChar;
112     }
113 
getTranspositionCost(const DicTraverseSession * const traverseSession,const DicNode * const parentDicNode,const DicNode * const dicNode)114     float getTranspositionCost(const DicTraverseSession *const traverseSession,
115             const DicNode *const parentDicNode, const DicNode *const dicNode) const {
116         const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
117         const int prevCodePoint = parentDicNode->getNodeCodePoint();
118         const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
119                 parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint));
120         const int codePoint = dicNode->getNodeCodePoint();
121         const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
122                 parentPointIndex, CharUtils::toBaseLowerCase(codePoint));
123         const float distance = distance1 + distance2;
124         const float weightedLengthDistance =
125                 distance * ScoringParams::DISTANCE_WEIGHT_LENGTH;
126         return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
127     }
128 
getInsertionCost(const DicTraverseSession * const traverseSession,const DicNode * const parentDicNode,const DicNode * const dicNode)129     float getInsertionCost(const DicTraverseSession *const traverseSession,
130             const DicNode *const parentDicNode, const DicNode *const dicNode) const {
131         const int16_t insertedPointIndex = parentDicNode->getInputIndex(0);
132         const int prevCodePoint = traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(
133                 insertedPointIndex);
134         const int currentCodePoint = dicNode->getNodeCodePoint();
135         const bool sameCodePoint = prevCodePoint == currentCodePoint;
136         const bool existsAdjacentProximityChars = traverseSession->getProximityInfoState(0)
137                 ->existsAdjacentProximityChars(insertedPointIndex);
138         const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
139                 insertedPointIndex + 1, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
140         const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH;
141         const bool singleChar = dicNode->getNodeCodePointCount() == 1;
142         float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f);
143         if (sameCodePoint) {
144             cost += ScoringParams::INSERTION_COST_SAME_CHAR;
145         } else if (existsAdjacentProximityChars) {
146             cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR;
147         } else {
148             cost += ScoringParams::INSERTION_COST;
149         }
150         return cost + weightedDistance;
151     }
152 
getSpaceOmissionCost(const DicTraverseSession * const traverseSession,const DicNode * const dicNode,DicNode_InputStateG * inputStateG)153     float getSpaceOmissionCost(const DicTraverseSession *const traverseSession,
154             const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
155         const float cost = ScoringParams::SPACE_OMISSION_COST;
156         return cost * traverseSession->getMultiWordCostMultiplier();
157     }
158 
getNewWordBigramLanguageCost(const DicTraverseSession * const traverseSession,const DicNode * const dicNode,MultiBigramMap * const multiBigramMap)159     float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession,
160             const DicNode *const dicNode,
161             MultiBigramMap *const multiBigramMap) const {
162         return DicNodeUtils::getBigramNodeImprobability(
163                 traverseSession->getDictionaryStructurePolicy(),
164                 dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
165     }
166 
getCompletionCost(const DicTraverseSession * const traverseSession,const DicNode * const dicNode)167     float getCompletionCost(const DicTraverseSession *const traverseSession,
168             const DicNode *const dicNode) const {
169         // The auto completion starts when the input index is same as the input size
170         const bool firstCompletion = dicNode->getInputIndex(0)
171                 == traverseSession->getInputSize();
172         // TODO: Change the cost for the first completion for the gesture?
173         const float cost = firstCompletion ? ScoringParams::COST_FIRST_COMPLETION
174                 : ScoringParams::COST_COMPLETION;
175         return cost;
176     }
177 
getTerminalLanguageCost(const DicTraverseSession * const traverseSession,const DicNode * const dicNode,const float dicNodeLanguageImprobability)178     float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
179             const DicNode *const dicNode, const float dicNodeLanguageImprobability) const {
180         return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
181     }
182 
getTerminalInsertionCost(const DicTraverseSession * const traverseSession,const DicNode * const dicNode)183     float getTerminalInsertionCost(const DicTraverseSession *const traverseSession,
184             const DicNode *const dicNode) const {
185         const int inputIndex = dicNode->getInputIndex(0);
186         const int inputSize = traverseSession->getInputSize();
187         ASSERT(inputIndex < inputSize);
188         // TODO: Implement more efficient logic
189         return  ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex);
190     }
191 
needsToNormalizeCompoundDistance()192     AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
193         return false;
194     }
195 
getAdditionalProximityCost()196     AK_FORCE_INLINE float getAdditionalProximityCost() const {
197         return ScoringParams::ADDITIONAL_PROXIMITY_COST;
198     }
199 
getSubstitutionCost()200     AK_FORCE_INLINE float getSubstitutionCost() const {
201         return ScoringParams::SUBSTITUTION_COST;
202     }
203 
getSpaceSubstitutionCost(const DicTraverseSession * const traverseSession,const DicNode * const dicNode)204     AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
205             const DicNode *const dicNode) const {
206         const int inputIndex = dicNode->getInputIndex(0);
207         const float distanceToSpaceKey = traverseSession->getProximityInfoState(0)
208                 ->getPointToKeyLength(inputIndex, KEYCODE_SPACE);
209         const float cost = ScoringParams::SPACE_SUBSTITUTION_COST * distanceToSpaceKey;
210         return cost * traverseSession->getMultiWordCostMultiplier();
211     }
212 
213     ErrorTypeUtils::ErrorType getErrorType(const CorrectionType correctionType,
214             const DicTraverseSession *const traverseSession,
215             const DicNode *const parentDicNode, const DicNode *const dicNode) const;
216 
217  private:
218     DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
219     static const TypingWeighting sInstance;
220 
TypingWeighting()221     TypingWeighting() {}
~TypingWeighting()222     ~TypingWeighting() {}
223 };
224 } // namespace latinime
225 #endif // LATINIME_TYPING_WEIGHTING_H
226