1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <google/protobuf/text_format.h>

#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "Converter.h"
29
30 namespace android::nn::fuzz {
31 namespace {
32
33 using namespace test_helper;
34 using namespace android_nn_fuzz;
35
convert(TestOperandType type)36 OperandType convert(TestOperandType type) {
37 return static_cast<OperandType>(type);
38 }
39
convert(TestOperationType type)40 OperationType convert(TestOperationType type) {
41 return static_cast<OperationType>(type);
42 }
43
convert(TestOperandLifeTime lifetime)44 Operand::LifeTime convert(TestOperandLifeTime lifetime) {
45 return static_cast<Operand::LifeTime>(lifetime);
46 }
47
convert(const std::vector<float> & scales)48 Scales convert(const std::vector<float>& scales) {
49 Scales protoScales;
50 for (float scale : scales) {
51 protoScales.add_scale(scale);
52 }
53 return protoScales;
54 }
55
convert(const TestSymmPerChannelQuantParams & params)56 SymmPerChannelQuantParams convert(const TestSymmPerChannelQuantParams& params) {
57 SymmPerChannelQuantParams protoParams;
58 *protoParams.mutable_scales() = convert(params.scales);
59 protoParams.set_channel_dim(params.channelDim);
60 return protoParams;
61 }
62
convertDimensions(const std::vector<uint32_t> & dimensions)63 Dimensions convertDimensions(const std::vector<uint32_t>& dimensions) {
64 Dimensions protoDimensions;
65 for (uint32_t dimension : dimensions) {
66 protoDimensions.add_dimension(dimension);
67 }
68 return protoDimensions;
69 }
70
getHashValue(const TestBuffer & buffer)71 uint32_t getHashValue(const TestBuffer& buffer) {
72 const char* ptr = buffer.get<char>();
73 const size_t size = buffer.size();
74 const std::string_view view(ptr, size);
75 const size_t value = std::hash<std::string_view>{}(view);
76 return static_cast<uint32_t>(value & 0xFFFFFFFF);
77 }
78
convert(bool noValue,const TestBuffer & buffer)79 Buffer convert(bool noValue, const TestBuffer& buffer) {
80 Buffer protoBuffer;
81 const uint32_t randomSeed = (noValue ? 0 : getHashValue(buffer));
82 protoBuffer.set_random_seed(randomSeed);
83 return protoBuffer;
84 }
85
convert(const TestOperand & operand)86 Operand convert(const TestOperand& operand) {
87 Operand protoOperand;
88 protoOperand.set_type(convert(operand.type));
89 *protoOperand.mutable_dimensions() = convertDimensions(operand.dimensions);
90 protoOperand.set_scale(operand.scale);
91 protoOperand.set_zero_point(operand.zeroPoint);
92 protoOperand.set_lifetime(convert(operand.lifetime));
93 *protoOperand.mutable_channel_quant() = convert(operand.channelQuant);
94 const bool noValue = (operand.lifetime == TestOperandLifeTime::NO_VALUE);
95 *protoOperand.mutable_data() = convert(noValue, operand.data);
96 return protoOperand;
97 }
98
convert(const std::vector<TestOperand> & operands)99 Operands convert(const std::vector<TestOperand>& operands) {
100 Operands protoOperands;
101 for (const auto& operand : operands) {
102 *protoOperands.add_operand() = convert(operand);
103 }
104 return protoOperands;
105 }
106
convertIndexes(const std::vector<uint32_t> & indexes)107 Indexes convertIndexes(const std::vector<uint32_t>& indexes) {
108 Indexes protoIndexes;
109 for (uint32_t index : indexes) {
110 protoIndexes.add_index(index);
111 }
112 return protoIndexes;
113 }
114
convert(const TestOperation & operation)115 Operation convert(const TestOperation& operation) {
116 Operation protoOperation;
117 protoOperation.set_type(convert(operation.type));
118 *protoOperation.mutable_inputs() = convertIndexes(operation.inputs);
119 *protoOperation.mutable_outputs() = convertIndexes(operation.outputs);
120 return protoOperation;
121 }
122
convert(const std::vector<TestOperation> & operations)123 Operations convert(const std::vector<TestOperation>& operations) {
124 Operations protoOperations;
125 for (const auto& operation : operations) {
126 *protoOperations.add_operation() = convert(operation);
127 }
128 return protoOperations;
129 }
130
convert(const TestModel & model)131 Model convert(const TestModel& model) {
132 Model protoModel;
133 *protoModel.mutable_operands() = convert(model.operands);
134 *protoModel.mutable_operations() = convert(model.operations);
135 *protoModel.mutable_input_indexes() = convertIndexes(model.inputIndexes);
136 *protoModel.mutable_output_indexes() = convertIndexes(model.outputIndexes);
137 protoModel.set_is_relaxed(model.isRelaxed);
138 return protoModel;
139 }
140
convertToTest(const TestModel & model)141 Test convertToTest(const TestModel& model) {
142 Test protoTest;
143 *protoTest.mutable_model() = convert(model);
144 return protoTest;
145 }
146
saveMessageAsText(const google::protobuf::Message & message)147 std::string saveMessageAsText(const google::protobuf::Message& message) {
148 std::string str;
149 if (!google::protobuf::TextFormat::PrintToString(message, &str)) {
150 return {};
151 }
152 return str;
153 }
154
createCorpusEntry(const std::pair<std::string,const TestModel * > & testCase,const std::string & genDir)155 void createCorpusEntry(const std::pair<std::string, const TestModel*>& testCase,
156 const std::string& genDir) {
157 const auto& [testName, testModel] = testCase;
158 const Test test = convertToTest(*testModel);
159 const std::string contents = saveMessageAsText(test);
160 const std::string fullName = genDir + "/" + testName;
161 std::ofstream file(fullName.c_str());
162 if (file.good()) {
163 file << contents;
164 }
165 }
166
167 } // anonymous namespace
168 } // namespace android::nn::fuzz
169
170 using ::android::nn::fuzz::createCorpusEntry;
171 using ::test_helper::TestModel;
172 using ::test_helper::TestModelManager;
173
main(int argc,char * argv[])174 int main(int argc, char* argv[]) {
175 if (argc != 2) {
176 std::cerr << "error: nnapi_fuzz_generate_corpus requires one argument" << std::endl;
177 return -1;
178 }
179 const std::string genDir = argv[1];
180 const auto filter = [](const TestModel& testModel) { return !testModel.expectFailure; };
181 const auto testModels = TestModelManager::get().getTestModels(filter);
182 std::for_each(testModels.begin(), testModels.end(),
183 [&genDir](const auto& testCase) { createCorpusEntry(testCase, genDir); });
184 return EXIT_SUCCESS;
185 }
186