1 /*
2  * Copyright (C) 2021-2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *    http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "pake_v1_client_task.h"
17 #include "device_auth_defines.h"
18 #include "hc_log.h"
19 #include "hc_types.h"
20 #include "pake_v1_client_protocol_task.h"
21 #include "pake_v1_protocol_task_common.h"
22 #include "pake_task_common.h"
23 #include "standard_client_bind_exchange_task.h"
24 
GetPakeV1ClientTaskType(const struct SubTaskBaseT * task)25 static int GetPakeV1ClientTaskType(const struct SubTaskBaseT *task)
26 {
27     PakeV1ClientTask *realTask = (PakeV1ClientTask *)task;
28     if (realTask->curTask == NULL) {
29         LOGE("CurTask is null.");
30         return TASK_TYPE_NONE;
31     }
32     return realTask->curTask->getCurTaskType();
33 }
34 
DestroyPakeV1ClientTask(struct SubTaskBaseT * task)35 static void DestroyPakeV1ClientTask(struct SubTaskBaseT *task)
36 {
37     PakeV1ClientTask *innerTask = (PakeV1ClientTask *)task;
38     if (innerTask == NULL) {
39         return;
40     }
41 
42     DestroyDasPakeV1Params(&(innerTask->params));
43     if (innerTask->curTask != NULL) {
44         innerTask->curTask->destroyTask(innerTask->curTask);
45     }
46     HcFree(innerTask);
47 }
48 
/*
 * Replaces the current sub-task with a standard bind-exchange client sub-task
 * and immediately runs its first processing step on the same input/output.
 *
 * On creation failure realTask->curTask is left NULL and HC_ERROR is returned.
 * Otherwise returns the new sub-task's process() result, logging on failure.
 */
static int CreateAndProcessNextBindTask(PakeV1ClientTask *realTask, const CJson *in, CJson *out, int *status)
{
    /* The previous sub-task is released before its successor is created. */
    realTask->curTask->destroyTask(realTask->curTask);
    realTask->curTask = CreateStandardBindExchangeClientTask();
    if (realTask->curTask != NULL) {
        int ret = realTask->curTask->process(realTask->curTask, &(realTask->params), in, out, status);
        if (ret != HC_SUCCESS) {
            LOGE("Process StandardBindExchangeClientTask failed.");
        }
        return ret;
    }
    LOGE("CreateStandardBindExchangeClientTask failed.");
    return HC_ERROR;
}
63 
/*
 * Advances the client once the current sub-task has reported FINISH.
 *
 * - OP_BIND: if the current sub-task is not yet the standard bind-exchange one,
 *   *status is reset to CONTINUE and the bind-exchange sub-task is created and run.
 * - AUTH_KEY_AGREEMENT / AUTHENTICATE: no further sub-task is created.
 * - Any other opCode: HC_ERR_NOT_SUPPORT.
 *
 * If *status is FINISH afterwards, SendResultToSelf() is called with the params.
 * Returns HC_SUCCESS or the first error encountered.
 */
static int CreateNextTask(PakeV1ClientTask *realTask, const CJson *in, CJson *out, int *status)
{
    int ret;
    /* opCode is re-read from params on each use because sub-tasks receive &params. */
    switch (realTask->params.opCode) {
        case OP_BIND:
            if (realTask->curTask->getCurTaskType() != TASK_TYPE_BIND_STANDARD_EXCHANGE) {
                *status = CONTINUE;
                ret = CreateAndProcessNextBindTask(realTask, in, out, status);
            } else {
                ret = HC_SUCCESS;
            }
            break;
        case AUTH_KEY_AGREEMENT:
        case AUTHENTICATE:
            ret = HC_SUCCESS;
            break;
        default:
            LOGE("Unsupported opCode: %d.", realTask->params.opCode);
            ret = HC_ERR_NOT_SUPPORT;
            break;
    }

    if (ret != HC_SUCCESS) {
        LOGE("Create and process next task failed, opcode: %d, res: %d.", realTask->params.opCode, ret);
        return ret;
    }
    if (*status != FINISH) {
        /* More messages are expected; nothing to report yet. */
        return HC_SUCCESS;
    }

    ret = SendResultToSelf(&realTask->params, out);
    if (ret != HC_SUCCESS) {
        LOGE("SendResultToSelf failed, res: %d", ret);
        return ret;
    }
    LOGI("End client task successfully, opcode: %d.", realTask->params.opCode);
    return HC_SUCCESS;
}
97 
/*
 * Handles one step of the PAKE V1 client flow.
 *
 * Before delegating to the current sub-task, supportedPakeAlg and isPskSupported
 * are recomputed from taskBase.curVersion. When the sub-task reports FINISH,
 * control passes to CreateNextTask() to decide what happens next.
 *
 * Returns HC_ERR_NULL_PTR if no sub-task is attached, otherwise the first
 * non-success result, or HC_SUCCESS.
 */
static int Process(struct SubTaskBaseT *task, const CJson *in, CJson *out, int *status)
{
    PakeV1ClientTask *clientTask = (PakeV1ClientTask *)task;
    if (clientTask->curTask == NULL) {
        LOGE("CurTask is null.");
        return HC_ERR_NULL_PTR;
    }

    clientTask->params.baseParams.supportedPakeAlg =
        GetSupportedPakeAlg(&clientTask->taskBase.curVersion, PAKE_V1);
    clientTask->params.isPskSupported = IsSupportedPsk(&clientTask->taskBase.curVersion);

    int ret = clientTask->curTask->process(clientTask->curTask, &clientTask->params, in, out, status);
    if (ret != HC_SUCCESS) {
        LOGE("CurTask processes failed, res: %x.", ret);
        return ret;
    }
    /* Only move on to the next stage once the current sub-task has finished. */
    return (*status == FINISH) ? CreateNextTask(clientTask, in, out, status) : ret;
}
118 
CreatePakeV1ClientTask(const CJson * in)119 SubTaskBase *CreatePakeV1ClientTask(const CJson *in)
120 {
121     PakeV1ClientTask *task = (PakeV1ClientTask *)HcMalloc(sizeof(PakeV1ClientTask), 0);
122     if (task == NULL) {
123         LOGE("Malloc for PakeV1ClientTask failed.");
124         return NULL;
125     }
126 
127     task->taskBase.getTaskType = GetPakeV1ClientTaskType;
128     task->taskBase.destroyTask = DestroyPakeV1ClientTask;
129     task->taskBase.process = Process;
130 
131     int res = InitDasPakeV1Params(&(task->params), in);
132     if (res != HC_SUCCESS) {
133         LOGE("Init das pake params failed, res: %d.", res);
134         DestroyPakeV1ClientTask((struct SubTaskBaseT *)task);
135         return NULL;
136     }
137     task->curTask = CreatePakeV1ProtocolClientTask();
138     if (task->curTask == NULL) {
139         LOGE("Create pake protocol client task failed.");
140         DestroyPakeV1ClientTask((struct SubTaskBaseT *)task);
141         return NULL;
142     }
143     return (SubTaskBase *)task;
144 }
145