1 /*
2  * Copyright (C) 2013 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LATINIME_WEIGHTING_H
18 #define LATINIME_WEIGHTING_H
19 
20 #include "defines.h"
21 #include "suggest/core/dictionary/error_type_utils.h"
22 
23 namespace latinime {
24 
25 class DicNode;
26 class DicTraverseSession;
27 struct DicNode_InputStateG;
28 class MultiBigramMap;
29 
30 class Weighting {
31  public:
32     static void addCostAndForwardInputIndex(const Weighting *const weighting,
33             const CorrectionType correctionType,
34             const DicTraverseSession *const traverseSession,
35             const DicNode *const parentDicNode, DicNode *const dicNode,
36             MultiBigramMap *const multiBigramMap);
37 
38  protected:
39     virtual float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
40             const DicNode *const dicNode) const = 0;
41 
42     virtual float getOmissionCost(
43          const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
44 
45     virtual float getMatchedCost(
46             const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
47             DicNode_InputStateG *inputStateG) const = 0;
48 
49     virtual bool isProximityDicNode(const DicTraverseSession *const traverseSession,
50             const DicNode *const dicNode) const = 0;
51 
52     virtual float getTranspositionCost(
53             const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
54             const DicNode *const dicNode) const = 0;
55 
56     virtual float getInsertionCost(
57             const DicTraverseSession *const traverseSession,
58             const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
59 
60     virtual float getSpaceOmissionCost(const DicTraverseSession *const traverseSession,
61             const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) const = 0;
62 
63     virtual float getNewWordBigramLanguageCost(
64             const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
65             MultiBigramMap *const multiBigramMap) const = 0;
66 
67     virtual float getCompletionCost(
68             const DicTraverseSession *const traverseSession,
69             const DicNode *const dicNode) const = 0;
70 
71     virtual float getTerminalInsertionCost(
72             const DicTraverseSession *const traverseSession,
73             const DicNode *const dicNode) const = 0;
74 
75     virtual float getTerminalLanguageCost(
76             const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
77             float dicNodeLanguageImprobability) const = 0;
78 
79     virtual bool needsToNormalizeCompoundDistance() const = 0;
80 
81     virtual float getAdditionalProximityCost() const = 0;
82 
83     virtual float getSubstitutionCost() const = 0;
84 
85     virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
86             const DicNode *const dicNode) const = 0;
87 
88     virtual ErrorTypeUtils::ErrorType getErrorType(const CorrectionType correctionType,
89             const DicTraverseSession *const traverseSession,
90             const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
91 
Weighting()92     Weighting() {}
~Weighting()93     virtual ~Weighting() {}
94 
95  private:
96     DISALLOW_COPY_AND_ASSIGN(Weighting);
97 
98     static float getSpatialCost(const Weighting *const weighting,
99             const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
100             const DicNode *const parentDicNode, const DicNode *const dicNode,
101             DicNode_InputStateG *const inputStateG);
102     static float getLanguageCost(const Weighting *const weighting,
103             const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
104             const DicNode *const parentDicNode, const DicNode *const dicNode,
105             MultiBigramMap *const multiBigramMap);
106     // TODO: Move to TypingWeighting and GestureWeighting?
107     static int getForwardInputCount(const CorrectionType correctionType);
108 };
109 } // namespace latinime
110 #endif // LATINIME_WEIGHTING_H
111