1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <google/protobuf/text_format.h>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "Converter.h"
29 
30 namespace android::nn::fuzz {
31 namespace {
32 
33 using namespace test_helper;
34 using namespace android_nn_fuzz;
35 
convert(TestOperandType type)36 OperandType convert(TestOperandType type) {
37     return static_cast<OperandType>(type);
38 }
39 
convert(TestOperationType type)40 OperationType convert(TestOperationType type) {
41     return static_cast<OperationType>(type);
42 }
43 
convert(TestOperandLifeTime lifetime)44 Operand::LifeTime convert(TestOperandLifeTime lifetime) {
45     return static_cast<Operand::LifeTime>(lifetime);
46 }
47 
convert(const std::vector<float> & scales)48 Scales convert(const std::vector<float>& scales) {
49     Scales protoScales;
50     for (float scale : scales) {
51         protoScales.add_scale(scale);
52     }
53     return protoScales;
54 }
55 
convert(const TestSymmPerChannelQuantParams & params)56 SymmPerChannelQuantParams convert(const TestSymmPerChannelQuantParams& params) {
57     SymmPerChannelQuantParams protoParams;
58     *protoParams.mutable_scales() = convert(params.scales);
59     protoParams.set_channel_dim(params.channelDim);
60     return protoParams;
61 }
62 
convertDimensions(const std::vector<uint32_t> & dimensions)63 Dimensions convertDimensions(const std::vector<uint32_t>& dimensions) {
64     Dimensions protoDimensions;
65     for (uint32_t dimension : dimensions) {
66         protoDimensions.add_dimension(dimension);
67     }
68     return protoDimensions;
69 }
70 
getHashValue(const TestBuffer & buffer)71 uint32_t getHashValue(const TestBuffer& buffer) {
72     const char* ptr = buffer.get<char>();
73     const size_t size = buffer.size();
74     const std::string_view view(ptr, size);
75     const size_t value = std::hash<std::string_view>{}(view);
76     return static_cast<uint32_t>(value & 0xFFFFFFFF);
77 }
78 
convert(bool noValue,const TestBuffer & buffer)79 Buffer convert(bool noValue, const TestBuffer& buffer) {
80     Buffer protoBuffer;
81     const uint32_t randomSeed = (noValue ? 0 : getHashValue(buffer));
82     protoBuffer.set_random_seed(randomSeed);
83     return protoBuffer;
84 }
85 
convert(const TestOperand & operand)86 Operand convert(const TestOperand& operand) {
87     Operand protoOperand;
88     protoOperand.set_type(convert(operand.type));
89     *protoOperand.mutable_dimensions() = convertDimensions(operand.dimensions);
90     protoOperand.set_scale(operand.scale);
91     protoOperand.set_zero_point(operand.zeroPoint);
92     protoOperand.set_lifetime(convert(operand.lifetime));
93     *protoOperand.mutable_channel_quant() = convert(operand.channelQuant);
94     const bool noValue = (operand.lifetime == TestOperandLifeTime::NO_VALUE);
95     *protoOperand.mutable_data() = convert(noValue, operand.data);
96     return protoOperand;
97 }
98 
convert(const std::vector<TestOperand> & operands)99 Operands convert(const std::vector<TestOperand>& operands) {
100     Operands protoOperands;
101     for (const auto& operand : operands) {
102         *protoOperands.add_operand() = convert(operand);
103     }
104     return protoOperands;
105 }
106 
convertIndexes(const std::vector<uint32_t> & indexes)107 Indexes convertIndexes(const std::vector<uint32_t>& indexes) {
108     Indexes protoIndexes;
109     for (uint32_t index : indexes) {
110         protoIndexes.add_index(index);
111     }
112     return protoIndexes;
113 }
114 
convert(const TestOperation & operation)115 Operation convert(const TestOperation& operation) {
116     Operation protoOperation;
117     protoOperation.set_type(convert(operation.type));
118     *protoOperation.mutable_inputs() = convertIndexes(operation.inputs);
119     *protoOperation.mutable_outputs() = convertIndexes(operation.outputs);
120     return protoOperation;
121 }
122 
convert(const std::vector<TestOperation> & operations)123 Operations convert(const std::vector<TestOperation>& operations) {
124     Operations protoOperations;
125     for (const auto& operation : operations) {
126         *protoOperations.add_operation() = convert(operation);
127     }
128     return protoOperations;
129 }
130 
convert(const TestModel & model)131 Model convert(const TestModel& model) {
132     Model protoModel;
133     *protoModel.mutable_operands() = convert(model.operands);
134     *protoModel.mutable_operations() = convert(model.operations);
135     *protoModel.mutable_input_indexes() = convertIndexes(model.inputIndexes);
136     *protoModel.mutable_output_indexes() = convertIndexes(model.outputIndexes);
137     protoModel.set_is_relaxed(model.isRelaxed);
138     return protoModel;
139 }
140 
convertToTest(const TestModel & model)141 Test convertToTest(const TestModel& model) {
142     Test protoTest;
143     *protoTest.mutable_model() = convert(model);
144     return protoTest;
145 }
146 
saveMessageAsText(const google::protobuf::Message & message)147 std::string saveMessageAsText(const google::protobuf::Message& message) {
148     std::string str;
149     if (!google::protobuf::TextFormat::PrintToString(message, &str)) {
150         return {};
151     }
152     return str;
153 }
154 
createCorpusEntry(const std::pair<std::string,const TestModel * > & testCase,const std::string & genDir)155 void createCorpusEntry(const std::pair<std::string, const TestModel*>& testCase,
156                        const std::string& genDir) {
157     const auto& [testName, testModel] = testCase;
158     const Test test = convertToTest(*testModel);
159     const std::string contents = saveMessageAsText(test);
160     const std::string fullName = genDir + "/" + testName;
161     std::ofstream file(fullName.c_str());
162     if (file.good()) {
163         file << contents;
164     }
165 }
166 
167 }  // anonymous namespace
168 }  // namespace android::nn::fuzz
169 
170 using ::android::nn::fuzz::createCorpusEntry;
171 using ::test_helper::TestModel;
172 using ::test_helper::TestModelManager;
173 
main(int argc,char * argv[])174 int main(int argc, char* argv[]) {
175     if (argc != 2) {
176         std::cerr << "error: nnapi_fuzz_generate_corpus requires one argument" << std::endl;
177         return -1;
178     }
179     const std::string genDir = argv[1];
180     const auto filter = [](const TestModel& testModel) { return !testModel.expectFailure; };
181     const auto testModels = TestModelManager::get().getTestModels(filter);
182     std::for_each(testModels.begin(), testModels.end(),
183                   [&genDir](const auto& testCase) { createCorpusEntry(testCase, genDir); });
184     return EXIT_SUCCESS;
185 }
186