1 /*
2  * Copyright 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <binder/SafeInterface.h>
18 
19 #include <binder/IInterface.h>
20 #include <binder/IPCThreadState.h>
21 #include <binder/IServiceManager.h>
22 #include <binder/Parcel.h>
23 #include <binder/Parcelable.h>
24 #include <binder/ProcessState.h>
25 
26 #pragma clang diagnostic push
27 #pragma clang diagnostic ignored "-Weverything"
28 #include <gtest/gtest.h>
29 #pragma clang diagnostic pop
30 
31 #include <utils/LightRefBase.h>
32 #include <utils/NativeHandle.h>
33 
34 #include <cutils/native_handle.h>
35 
36 #include <optional>
37 
38 #include <sys/eventfd.h>
39 #include <sys/prctl.h>
40 
41 using namespace std::chrono_literals; // NOLINT - google-build-using-namespace
42 
43 namespace android {
44 namespace tests {
45 
46 static const String16 kServiceName("SafeInterfaceTest");
47 
// Small scoped enum (backed by uint32_t) used to exercise enum parceling.
// The enumerators are consecutive starting at zero.
enum class TestEnum : uint32_t {
    INVALID = 0,
    INITIAL,  // == 1
    FINAL,    // == 2
};
53 
54 // This class serves two purposes:
55 //   1) It ensures that the implementation doesn't require copying or moving the data (for
56 //      efficiency purposes)
57 //   2) It tests that Parcelables can be passed correctly
58 class NoCopyNoMove : public Parcelable {
59 public:
60     NoCopyNoMove() = default;
NoCopyNoMove(int32_t value)61     explicit NoCopyNoMove(int32_t value) : mValue(value) {}
62     ~NoCopyNoMove() override = default;
63 
64     // Not copyable
65     NoCopyNoMove(const NoCopyNoMove&) = delete;
66     NoCopyNoMove& operator=(const NoCopyNoMove&) = delete;
67 
68     // Not movable
69     NoCopyNoMove(NoCopyNoMove&&) = delete;
70     NoCopyNoMove& operator=(NoCopyNoMove&&) = delete;
71 
72     // Parcelable interface
writeToParcel(Parcel * parcel) const73     status_t writeToParcel(Parcel* parcel) const override { return parcel->writeInt32(mValue); }
readFromParcel(const Parcel * parcel)74     status_t readFromParcel(const Parcel* parcel) override { return parcel->readInt32(&mValue); }
75 
getValue() const76     int32_t getValue() const { return mValue; }
setValue(int32_t value)77     void setValue(int32_t value) { mValue = value; }
78 
79 private:
80     int32_t mValue = 0;
81     __attribute__((unused)) uint8_t mPadding[4] = {}; // Avoids a warning from -Wpadded
82 };
83 
84 struct TestFlattenable : Flattenable<TestFlattenable> {
85     TestFlattenable() = default;
TestFlattenableandroid::tests::TestFlattenable86     explicit TestFlattenable(int32_t v) : value(v) {}
87 
88     // Flattenable protocol
getFlattenedSizeandroid::tests::TestFlattenable89     size_t getFlattenedSize() const { return sizeof(value); }
getFdCountandroid::tests::TestFlattenable90     size_t getFdCount() const { return 0; }
flattenandroid::tests::TestFlattenable91     status_t flatten(void*& buffer, size_t& size, int*& /*fds*/, size_t& /*count*/) const {
92         FlattenableUtils::write(buffer, size, value);
93         return NO_ERROR;
94     }
unflattenandroid::tests::TestFlattenable95     status_t unflatten(void const*& buffer, size_t& size, int const*& /*fds*/, size_t& /*count*/) {
96         FlattenableUtils::read(buffer, size, value);
97         return NO_ERROR;
98     }
99 
100     int32_t value = 0;
101 };
102 
103 struct TestLightFlattenable : LightFlattenablePod<TestLightFlattenable> {
104     TestLightFlattenable() = default;
TestLightFlattenableandroid::tests::TestLightFlattenable105     explicit TestLightFlattenable(int32_t v) : value(v) {}
106     int32_t value = 0;
107 };
108 
109 // It seems like this should be able to inherit from TestFlattenable (to avoid duplicating code),
110 // but the SafeInterface logic can't easily be extended to find an indirect Flattenable<T>
111 // base class
112 class TestLightRefBaseFlattenable : public Flattenable<TestLightRefBaseFlattenable>,
113                                     public LightRefBase<TestLightRefBaseFlattenable> {
114 public:
115     TestLightRefBaseFlattenable() = default;
TestLightRefBaseFlattenable(int32_t v)116     explicit TestLightRefBaseFlattenable(int32_t v) : value(v) {}
117 
118     // Flattenable protocol
getFlattenedSize() const119     size_t getFlattenedSize() const { return sizeof(value); }
getFdCount() const120     size_t getFdCount() const { return 0; }
flatten(void * & buffer,size_t & size,int * &,size_t &) const121     status_t flatten(void*& buffer, size_t& size, int*& /*fds*/, size_t& /*count*/) const {
122         FlattenableUtils::write(buffer, size, value);
123         return NO_ERROR;
124     }
unflatten(void const * & buffer,size_t & size,int const * &,size_t &)125     status_t unflatten(void const*& buffer, size_t& size, int const*& /*fds*/, size_t& /*count*/) {
126         FlattenableUtils::read(buffer, size, value);
127         return NO_ERROR;
128     }
129 
130     int32_t value = 0;
131 };
132 
133 class TestParcelable : public Parcelable {
134 public:
135     TestParcelable() = default;
TestParcelable(int32_t value)136     explicit TestParcelable(int32_t value) : mValue(value) {}
TestParcelable(const TestParcelable & other)137     TestParcelable(const TestParcelable& other) : TestParcelable(other.mValue) {}
TestParcelable(TestParcelable && other)138     TestParcelable(TestParcelable&& other) : TestParcelable(other.mValue) {}
139 
140     // Parcelable interface
writeToParcel(Parcel * parcel) const141     status_t writeToParcel(Parcel* parcel) const override { return parcel->writeInt32(mValue); }
readFromParcel(const Parcel * parcel)142     status_t readFromParcel(const Parcel* parcel) override { return parcel->readInt32(&mValue); }
143 
getValue() const144     int32_t getValue() const { return mValue; }
setValue(int32_t value)145     void setValue(int32_t value) { mValue = value; }
146 
147 private:
148     int32_t mValue = 0;
149 };
150 
151 class ExitOnDeath : public IBinder::DeathRecipient {
152 public:
153     ~ExitOnDeath() override = default;
154 
binderDied(const wp<IBinder> &)155     void binderDied(const wp<IBinder>& /*who*/) override {
156         ALOG(LOG_INFO, "ExitOnDeath", "Exiting");
157         exit(0);
158     }
159 };
160 
161 // This callback class is used to test both one-way transactions and that sp<IInterface> can be
162 // passed correctly
163 class ICallback : public IInterface {
164 public:
165     DECLARE_META_INTERFACE(Callback)
166 
167     enum class Tag : uint32_t {
168         OnCallback = IBinder::FIRST_CALL_TRANSACTION,
169         Last,
170     };
171 
172     virtual void onCallback(int32_t aPlusOne) = 0;
173 };
174 
175 class BpCallback : public SafeBpInterface<ICallback> {
176 public:
BpCallback(const sp<IBinder> & impl)177     explicit BpCallback(const sp<IBinder>& impl) : SafeBpInterface<ICallback>(impl, getLogTag()) {}
178 
onCallback(int32_t aPlusOne)179     void onCallback(int32_t aPlusOne) override {
180         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
181         return callRemoteAsync<decltype(&ICallback::onCallback)>(Tag::OnCallback, aPlusOne);
182     }
183 
184 private:
getLogTag()185     static constexpr const char* getLogTag() { return "BpCallback"; }
186 };
187 
// Provides the definitions matching DECLARE_META_INTERFACE(Callback) in ICallback. The
// -Wexit-time-destructors warning is silenced only around the macro expansion.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wexit-time-destructors"
IMPLEMENT_META_INTERFACE(Callback, "android.gfx.tests.ICallback")
#pragma clang diagnostic pop
192 
193 class BnCallback : public SafeBnInterface<ICallback> {
194 public:
BnCallback()195     BnCallback() : SafeBnInterface("BnCallback") {}
196 
onTransact(uint32_t code,const Parcel & data,Parcel * reply,uint32_t)197     status_t onTransact(uint32_t code, const Parcel& data, Parcel* reply,
198                         uint32_t /*flags*/) override {
199         EXPECT_GE(code, IBinder::FIRST_CALL_TRANSACTION);
200         EXPECT_LT(code, static_cast<uint32_t>(ICallback::Tag::Last));
201         ICallback::Tag tag = static_cast<ICallback::Tag>(code);
202         switch (tag) {
203             case ICallback::Tag::OnCallback: {
204                 return callLocalAsync(data, reply, &ICallback::onCallback);
205             }
206             case ICallback::Tag::Last:
207                 // Should not be possible because of the asserts at the beginning of the method
208                 [&]() { FAIL(); }();
209                 return UNKNOWN_ERROR;
210         }
211     }
212 };
213 
// Interface exercised by the tests. BpSafeInterfaceTest forwards each method with
// callRemote/callRemoteAsync and BnSafeInterfaceTest implements and dispatches them.
class ISafeInterfaceTest : public IInterface {
public:
    DECLARE_META_INTERFACE(SafeInterfaceTest)

    // Transaction codes, one per method below. Last is a sentinel one past the final valid
    // code and is only used for range checks.
    enum class Tag : uint32_t {
        SetDeathToken = IBinder::FIRST_CALL_TRANSACTION,
        ReturnsNoMemory,
        LogicalNot,
        ModifyEnum,
        IncrementFlattenable,
        IncrementLightFlattenable,
        IncrementLightRefBaseFlattenable,
        IncrementNativeHandle,
        IncrementNoCopyNoMove,
        IncrementParcelableVector,
        DoubleString,
        CallMeBack,
        IncrementInt32,
        IncrementUint32,
        IncrementInt64,
        IncrementUint64,
        IncrementFloat,
        IncrementTwo,
        Last,
    };

    // This is primarily so that the remote service dies when the test does, but it also serves to
    // test the handling of sp<IBinder> and non-const methods
    virtual status_t setDeathToken(const sp<IBinder>& token) = 0;

    // This is the most basic test since it doesn't require parceling any arguments
    virtual status_t returnsNoMemory() const = 0;

    // These are ordered according to their corresponding methods in SafeInterface::ParcelHandler.
    // Each method takes its input first and writes its result through the trailing pointer.
    virtual status_t logicalNot(bool a, bool* notA) const = 0;
    virtual status_t modifyEnum(TestEnum a, TestEnum* b) const = 0;
    virtual status_t increment(const TestFlattenable& a, TestFlattenable* aPlusOne) const = 0;
    virtual status_t increment(const TestLightFlattenable& a,
                               TestLightFlattenable* aPlusOne) const = 0;
    virtual status_t increment(const sp<TestLightRefBaseFlattenable>& a,
                               sp<TestLightRefBaseFlattenable>* aPlusOne) const = 0;
    virtual status_t increment(const sp<NativeHandle>& a, sp<NativeHandle>* aPlusOne) const = 0;
    virtual status_t increment(const NoCopyNoMove& a, NoCopyNoMove* aPlusOne) const = 0;
    virtual status_t increment(const std::vector<TestParcelable>& a,
                               std::vector<TestParcelable>* aPlusOne) const = 0;
    virtual status_t doubleString(const String8& str, String8* doubleStr) const = 0;
    // As mentioned above, sp<IBinder> is already tested by setDeathToken.
    // This is the only method here without a status result; it is sent with callRemoteAsync.
    virtual void callMeBack(const sp<ICallback>& callback, int32_t a) const = 0;
    virtual status_t increment(int32_t a, int32_t* aPlusOne) const = 0;
    virtual status_t increment(uint32_t a, uint32_t* aPlusOne) const = 0;
    virtual status_t increment(int64_t a, int64_t* aPlusOne) const = 0;
    virtual status_t increment(uint64_t a, uint64_t* aPlusOne) const = 0;
    virtual status_t increment(float a, float* aPlusOne) const = 0;

    // This tests that input/output parameter interleaving works correctly
    virtual status_t increment(int32_t a, int32_t* aPlusOne, int32_t b,
                               int32_t* bPlusOne) const = 0;
};
272 
273 class BpSafeInterfaceTest : public SafeBpInterface<ISafeInterfaceTest> {
274 public:
BpSafeInterfaceTest(const sp<IBinder> & impl)275     explicit BpSafeInterfaceTest(const sp<IBinder>& impl)
276           : SafeBpInterface<ISafeInterfaceTest>(impl, getLogTag()) {}
277 
setDeathToken(const sp<IBinder> & token)278     status_t setDeathToken(const sp<IBinder>& token) override {
279         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
280         return callRemote<decltype(&ISafeInterfaceTest::setDeathToken)>(Tag::SetDeathToken, token);
281     }
returnsNoMemory() const282     status_t returnsNoMemory() const override {
283         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
284         return callRemote<decltype(&ISafeInterfaceTest::returnsNoMemory)>(Tag::ReturnsNoMemory);
285     }
logicalNot(bool a,bool * notA) const286     status_t logicalNot(bool a, bool* notA) const override {
287         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
288         return callRemote<decltype(&ISafeInterfaceTest::logicalNot)>(Tag::LogicalNot, a, notA);
289     }
modifyEnum(TestEnum a,TestEnum * b) const290     status_t modifyEnum(TestEnum a, TestEnum* b) const override {
291         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
292         return callRemote<decltype(&ISafeInterfaceTest::modifyEnum)>(Tag::ModifyEnum, a, b);
293     }
increment(const TestFlattenable & a,TestFlattenable * aPlusOne) const294     status_t increment(const TestFlattenable& a, TestFlattenable* aPlusOne) const override {
295         using Signature =
296                 status_t (ISafeInterfaceTest::*)(const TestFlattenable&, TestFlattenable*) const;
297         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
298         return callRemote<Signature>(Tag::IncrementFlattenable, a, aPlusOne);
299     }
increment(const TestLightFlattenable & a,TestLightFlattenable * aPlusOne) const300     status_t increment(const TestLightFlattenable& a,
301                        TestLightFlattenable* aPlusOne) const override {
302         using Signature = status_t (ISafeInterfaceTest::*)(const TestLightFlattenable&,
303                                                            TestLightFlattenable*) const;
304         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
305         return callRemote<Signature>(Tag::IncrementLightFlattenable, a, aPlusOne);
306     }
increment(const sp<TestLightRefBaseFlattenable> & a,sp<TestLightRefBaseFlattenable> * aPlusOne) const307     status_t increment(const sp<TestLightRefBaseFlattenable>& a,
308                        sp<TestLightRefBaseFlattenable>* aPlusOne) const override {
309         using Signature = status_t (ISafeInterfaceTest::*)(const sp<TestLightRefBaseFlattenable>&,
310                                                            sp<TestLightRefBaseFlattenable>*) const;
311         return callRemote<Signature>(Tag::IncrementLightRefBaseFlattenable, a, aPlusOne);
312     }
increment(const sp<NativeHandle> & a,sp<NativeHandle> * aPlusOne) const313     status_t increment(const sp<NativeHandle>& a, sp<NativeHandle>* aPlusOne) const override {
314         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
315         using Signature =
316                 status_t (ISafeInterfaceTest::*)(const sp<NativeHandle>&, sp<NativeHandle>*) const;
317         return callRemote<Signature>(Tag::IncrementNativeHandle, a, aPlusOne);
318     }
increment(const NoCopyNoMove & a,NoCopyNoMove * aPlusOne) const319     status_t increment(const NoCopyNoMove& a, NoCopyNoMove* aPlusOne) const override {
320         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
321         using Signature = status_t (ISafeInterfaceTest::*)(const NoCopyNoMove& a,
322                                                            NoCopyNoMove* aPlusOne) const;
323         return callRemote<Signature>(Tag::IncrementNoCopyNoMove, a, aPlusOne);
324     }
increment(const std::vector<TestParcelable> & a,std::vector<TestParcelable> * aPlusOne) const325     status_t increment(const std::vector<TestParcelable>& a,
326                        std::vector<TestParcelable>* aPlusOne) const override {
327         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
328         using Signature = status_t (ISafeInterfaceTest::*)(const std::vector<TestParcelable>&,
329                                                            std::vector<TestParcelable>*);
330         return callRemote<Signature>(Tag::IncrementParcelableVector, a, aPlusOne);
331     }
doubleString(const String8 & str,String8 * doubleStr) const332     status_t doubleString(const String8& str, String8* doubleStr) const override {
333         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
334         return callRemote<decltype(&ISafeInterfaceTest::doubleString)>(Tag::DoubleString, str,
335                                                                        doubleStr);
336     }
callMeBack(const sp<ICallback> & callback,int32_t a) const337     void callMeBack(const sp<ICallback>& callback, int32_t a) const override {
338         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
339         return callRemoteAsync<decltype(&ISafeInterfaceTest::callMeBack)>(Tag::CallMeBack, callback,
340                                                                           a);
341     }
increment(int32_t a,int32_t * aPlusOne) const342     status_t increment(int32_t a, int32_t* aPlusOne) const override {
343         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
344         using Signature = status_t (ISafeInterfaceTest::*)(int32_t, int32_t*) const;
345         return callRemote<Signature>(Tag::IncrementInt32, a, aPlusOne);
346     }
increment(uint32_t a,uint32_t * aPlusOne) const347     status_t increment(uint32_t a, uint32_t* aPlusOne) const override {
348         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
349         using Signature = status_t (ISafeInterfaceTest::*)(uint32_t, uint32_t*) const;
350         return callRemote<Signature>(Tag::IncrementUint32, a, aPlusOne);
351     }
increment(int64_t a,int64_t * aPlusOne) const352     status_t increment(int64_t a, int64_t* aPlusOne) const override {
353         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
354         using Signature = status_t (ISafeInterfaceTest::*)(int64_t, int64_t*) const;
355         return callRemote<Signature>(Tag::IncrementInt64, a, aPlusOne);
356     }
increment(uint64_t a,uint64_t * aPlusOne) const357     status_t increment(uint64_t a, uint64_t* aPlusOne) const override {
358         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
359         using Signature = status_t (ISafeInterfaceTest::*)(uint64_t, uint64_t*) const;
360         return callRemote<Signature>(Tag::IncrementUint64, a, aPlusOne);
361     }
increment(float a,float * aPlusOne) const362     status_t increment(float a, float* aPlusOne) const override {
363         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
364         using Signature = status_t (ISafeInterfaceTest::*)(float, float*) const;
365         return callRemote<Signature>(Tag::IncrementFloat, a, aPlusOne);
366     }
increment(int32_t a,int32_t * aPlusOne,int32_t b,int32_t * bPlusOne) const367     status_t increment(int32_t a, int32_t* aPlusOne, int32_t b, int32_t* bPlusOne) const override {
368         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
369         using Signature =
370                 status_t (ISafeInterfaceTest::*)(int32_t, int32_t*, int32_t, int32_t*) const;
371         return callRemote<Signature>(Tag::IncrementTwo, a, aPlusOne, b, bPlusOne);
372     }
373 
374 private:
getLogTag()375     static constexpr const char* getLogTag() { return "BpSafeInterfaceTest"; }
376 };
377 
378 #pragma clang diagnostic push
379 #pragma clang diagnostic ignored "-Wexit-time-destructors"
// Provides the definitions matching DECLARE_META_INTERFACE(SafeInterfaceTest)
IMPLEMENT_META_INTERFACE(SafeInterfaceTest, "android.gfx.tests.ISafeInterfaceTest")
381 
382 static sp<IBinder::DeathRecipient> getDeathRecipient() {
383     static sp<IBinder::DeathRecipient> recipient = new ExitOnDeath;
384     return recipient;
385 }
386 #pragma clang diagnostic pop
387 
388 class BnSafeInterfaceTest : public SafeBnInterface<ISafeInterfaceTest> {
389 public:
BnSafeInterfaceTest()390     BnSafeInterfaceTest() : SafeBnInterface(getLogTag()) {}
391 
setDeathToken(const sp<IBinder> & token)392     status_t setDeathToken(const sp<IBinder>& token) override {
393         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
394         token->linkToDeath(getDeathRecipient());
395         return NO_ERROR;
396     }
returnsNoMemory() const397     status_t returnsNoMemory() const override {
398         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
399         return NO_MEMORY;
400     }
logicalNot(bool a,bool * notA) const401     status_t logicalNot(bool a, bool* notA) const override {
402         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
403         *notA = !a;
404         return NO_ERROR;
405     }
modifyEnum(TestEnum a,TestEnum * b) const406     status_t modifyEnum(TestEnum a, TestEnum* b) const override {
407         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
408         *b = (a == TestEnum::INITIAL) ? TestEnum::FINAL : TestEnum::INVALID;
409         return NO_ERROR;
410     }
increment(const TestFlattenable & a,TestFlattenable * aPlusOne) const411     status_t increment(const TestFlattenable& a, TestFlattenable* aPlusOne) const override {
412         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
413         aPlusOne->value = a.value + 1;
414         return NO_ERROR;
415     }
increment(const TestLightFlattenable & a,TestLightFlattenable * aPlusOne) const416     status_t increment(const TestLightFlattenable& a,
417                        TestLightFlattenable* aPlusOne) const override {
418         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
419         aPlusOne->value = a.value + 1;
420         return NO_ERROR;
421     }
increment(const sp<TestLightRefBaseFlattenable> & a,sp<TestLightRefBaseFlattenable> * aPlusOne) const422     status_t increment(const sp<TestLightRefBaseFlattenable>& a,
423                        sp<TestLightRefBaseFlattenable>* aPlusOne) const override {
424         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
425         *aPlusOne = new TestLightRefBaseFlattenable(a->value + 1);
426         return NO_ERROR;
427     }
increment(const sp<NativeHandle> & a,sp<NativeHandle> * aPlusOne) const428     status_t increment(const sp<NativeHandle>& a, sp<NativeHandle>* aPlusOne) const override {
429         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
430         native_handle* rawHandle = native_handle_create(1 /*numFds*/, 1 /*numInts*/);
431         if (rawHandle == nullptr) return NO_MEMORY;
432 
433         // Copy the fd over directly
434         rawHandle->data[0] = dup(a->handle()->data[0]);
435 
436         // Increment the int
437         rawHandle->data[1] = a->handle()->data[1] + 1;
438 
439         // This cannot fail, as it is just the sp<NativeHandle> taking responsibility for closing
440         // the native_handle when it goes out of scope
441         *aPlusOne = NativeHandle::create(rawHandle, true);
442         return NO_ERROR;
443     }
increment(const NoCopyNoMove & a,NoCopyNoMove * aPlusOne) const444     status_t increment(const NoCopyNoMove& a, NoCopyNoMove* aPlusOne) const override {
445         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
446         aPlusOne->setValue(a.getValue() + 1);
447         return NO_ERROR;
448     }
increment(const std::vector<TestParcelable> & a,std::vector<TestParcelable> * aPlusOne) const449     status_t increment(const std::vector<TestParcelable>& a,
450                        std::vector<TestParcelable>* aPlusOne) const override {
451         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
452         aPlusOne->resize(a.size());
453         for (size_t i = 0; i < a.size(); ++i) {
454             (*aPlusOne)[i].setValue(a[i].getValue() + 1);
455         }
456         return NO_ERROR;
457     }
doubleString(const String8 & str,String8 * doubleStr) const458     status_t doubleString(const String8& str, String8* doubleStr) const override {
459         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
460         *doubleStr = str + str;
461         return NO_ERROR;
462     }
callMeBack(const sp<ICallback> & callback,int32_t a) const463     void callMeBack(const sp<ICallback>& callback, int32_t a) const override {
464         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
465         callback->onCallback(a + 1);
466     }
increment(int32_t a,int32_t * aPlusOne) const467     status_t increment(int32_t a, int32_t* aPlusOne) const override {
468         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
469         *aPlusOne = a + 1;
470         return NO_ERROR;
471     }
increment(uint32_t a,uint32_t * aPlusOne) const472     status_t increment(uint32_t a, uint32_t* aPlusOne) const override {
473         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
474         *aPlusOne = a + 1;
475         return NO_ERROR;
476     }
increment(int64_t a,int64_t * aPlusOne) const477     status_t increment(int64_t a, int64_t* aPlusOne) const override {
478         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
479         *aPlusOne = a + 1;
480         return NO_ERROR;
481     }
increment(uint64_t a,uint64_t * aPlusOne) const482     status_t increment(uint64_t a, uint64_t* aPlusOne) const override {
483         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
484         *aPlusOne = a + 1;
485         return NO_ERROR;
486     }
increment(float a,float * aPlusOne) const487     status_t increment(float a, float* aPlusOne) const override {
488         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
489         *aPlusOne = a + 1.0f;
490         return NO_ERROR;
491     }
increment(int32_t a,int32_t * aPlusOne,int32_t b,int32_t * bPlusOne) const492     status_t increment(int32_t a, int32_t* aPlusOne, int32_t b, int32_t* bPlusOne) const override {
493         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
494         *aPlusOne = a + 1;
495         *bPlusOne = b + 1;
496         return NO_ERROR;
497     }
498 
499     // BnInterface
onTransact(uint32_t code,const Parcel & data,Parcel * reply,uint32_t)500     status_t onTransact(uint32_t code, const Parcel& data, Parcel* reply,
501                         uint32_t /*flags*/) override {
502         EXPECT_GE(code, IBinder::FIRST_CALL_TRANSACTION);
503         EXPECT_LT(code, static_cast<uint32_t>(Tag::Last));
504         ISafeInterfaceTest::Tag tag = static_cast<ISafeInterfaceTest::Tag>(code);
505         switch (tag) {
506             case ISafeInterfaceTest::Tag::SetDeathToken: {
507                 return callLocal(data, reply, &ISafeInterfaceTest::setDeathToken);
508             }
509             case ISafeInterfaceTest::Tag::ReturnsNoMemory: {
510                 return callLocal(data, reply, &ISafeInterfaceTest::returnsNoMemory);
511             }
512             case ISafeInterfaceTest::Tag::LogicalNot: {
513                 return callLocal(data, reply, &ISafeInterfaceTest::logicalNot);
514             }
515             case ISafeInterfaceTest::Tag::ModifyEnum: {
516                 return callLocal(data, reply, &ISafeInterfaceTest::modifyEnum);
517             }
518             case ISafeInterfaceTest::Tag::IncrementFlattenable: {
519                 using Signature = status_t (ISafeInterfaceTest::*)(const TestFlattenable& a,
520                                                                    TestFlattenable* aPlusOne) const;
521                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
522             }
523             case ISafeInterfaceTest::Tag::IncrementLightFlattenable: {
524                 using Signature =
525                         status_t (ISafeInterfaceTest::*)(const TestLightFlattenable& a,
526                                                          TestLightFlattenable* aPlusOne) const;
527                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
528             }
529             case ISafeInterfaceTest::Tag::IncrementLightRefBaseFlattenable: {
530                 using Signature =
531                         status_t (ISafeInterfaceTest::*)(const sp<TestLightRefBaseFlattenable>&,
532                                                          sp<TestLightRefBaseFlattenable>*) const;
533                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
534             }
535             case ISafeInterfaceTest::Tag::IncrementNativeHandle: {
536                 using Signature = status_t (ISafeInterfaceTest::*)(const sp<NativeHandle>&,
537                                                                    sp<NativeHandle>*) const;
538                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
539             }
540             case ISafeInterfaceTest::Tag::IncrementNoCopyNoMove: {
541                 using Signature = status_t (ISafeInterfaceTest::*)(const NoCopyNoMove& a,
542                                                                    NoCopyNoMove* aPlusOne) const;
543                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
544             }
545             case ISafeInterfaceTest::Tag::IncrementParcelableVector: {
546                 using Signature =
547                         status_t (ISafeInterfaceTest::*)(const std::vector<TestParcelable>&,
548                                                          std::vector<TestParcelable>*) const;
549                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
550             }
551             case ISafeInterfaceTest::Tag::DoubleString: {
552                 return callLocal(data, reply, &ISafeInterfaceTest::doubleString);
553             }
554             case ISafeInterfaceTest::Tag::CallMeBack: {
555                 return callLocalAsync(data, reply, &ISafeInterfaceTest::callMeBack);
556             }
557             case ISafeInterfaceTest::Tag::IncrementInt32: {
558                 using Signature = status_t (ISafeInterfaceTest::*)(int32_t, int32_t*) const;
559                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
560             }
561             case ISafeInterfaceTest::Tag::IncrementUint32: {
562                 using Signature = status_t (ISafeInterfaceTest::*)(uint32_t, uint32_t*) const;
563                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
564             }
565             case ISafeInterfaceTest::Tag::IncrementInt64: {
566                 using Signature = status_t (ISafeInterfaceTest::*)(int64_t, int64_t*) const;
567                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
568             }
569             case ISafeInterfaceTest::Tag::IncrementUint64: {
570                 using Signature = status_t (ISafeInterfaceTest::*)(uint64_t, uint64_t*) const;
571                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
572             }
573             case ISafeInterfaceTest::Tag::IncrementFloat: {
574                 using Signature = status_t (ISafeInterfaceTest::*)(float, float*) const;
575                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
576             }
577             case ISafeInterfaceTest::Tag::IncrementTwo: {
578                 using Signature = status_t (ISafeInterfaceTest::*)(int32_t, int32_t*, int32_t,
579                                                                    int32_t*) const;
580                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
581             }
582             case ISafeInterfaceTest::Tag::Last:
583                 // Should not be possible because of the asserts at the beginning of the method
584                 [&]() { FAIL(); }();
585                 return UNKNOWN_ERROR;
586         }
587     }
588 
589 private:
getLogTag()590     static constexpr const char* getLogTag() { return "BnSafeInterfaceTest"; }
591 };
592 
593 class SafeInterfaceTest : public ::testing::Test {
594 public:
SafeInterfaceTest()595     SafeInterfaceTest() : mSafeInterfaceTest(getRemoteService()) {
596         ProcessState::self()->startThreadPool();
597     }
598     ~SafeInterfaceTest() override = default;
599 
600 protected:
601     sp<ISafeInterfaceTest> mSafeInterfaceTest;
602 
603 private:
getLogTag()604     static constexpr const char* getLogTag() { return "SafeInterfaceTest"; }
605 
getRemoteService()606     sp<ISafeInterfaceTest> getRemoteService() {
607         sp<IBinder> binder = defaultServiceManager()->getService(kServiceName);
608         sp<ISafeInterfaceTest> iface = interface_cast<ISafeInterfaceTest>(binder);
609         EXPECT_TRUE(iface != nullptr);
610 
611         iface->setDeathToken(new BBinder);
612 
613         return iface;
614     }
615 };
616 
// A non-OK status returned by the remote call must reach the caller unchanged.
TEST_F(SafeInterfaceTest, TestReturnsNoMemory) {
    const status_t status = mSafeInterfaceTest->returnsNoMemory();
    ASSERT_EQ(NO_MEMORY, status);
}
621 
// Round-trips a bool through logicalNot() for both possible inputs.
TEST_F(SafeInterfaceTest, TestLogicalNot) {
    // Exercise true first, then false, so that a default-initialized false on either side of the
    // call cannot masquerade as a correct answer.
    for (const bool input : {true, false}) {
        // Seed the output with the un-negated value; a call that leaves it untouched must fail.
        bool negated = input;
        const status_t status = mSafeInterfaceTest->logicalNot(input, &negated);
        ASSERT_EQ(NO_ERROR, status);
        ASSERT_EQ(!input, negated);
    }
}
635 
// Sends TestEnum::INITIAL through modifyEnum() and expects TestEnum::FINAL in the output.
TEST_F(SafeInterfaceTest, TestModifyEnum) {
    const TestEnum input = TestEnum::INITIAL;
    TestEnum output = TestEnum::INVALID;
    const status_t status = mSafeInterfaceTest->modifyEnum(input, &output);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(TestEnum::FINAL, output);
}
643 
// Passes a TestFlattenable through increment() and expects its value to come back one higher.
TEST_F(SafeInterfaceTest, TestIncrementFlattenable) {
    const TestFlattenable input{1};
    TestFlattenable incremented{0};
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input.value + 1, incremented.value);
}
651 
// Passes a TestLightFlattenable through increment() and expects its value to come back one higher.
TEST_F(SafeInterfaceTest, TestIncrementLightFlattenable) {
    const TestLightFlattenable input{1};
    TestLightFlattenable incremented{0};
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input.value + 1, incremented.value);
}
659 
// Passes an sp<TestLightRefBaseFlattenable> through increment(); the output pointer starts null
// and must be populated with an object whose value is one higher than the input's.
TEST_F(SafeInterfaceTest, TestIncrementLightRefBaseFlattenable) {
    sp<TestLightRefBaseFlattenable> input = new TestLightRefBaseFlattenable{1};
    sp<TestLightRefBaseFlattenable> incremented;
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    // Check for null before dereferencing below.
    ASSERT_NE(nullptr, incremented.get());
    ASSERT_EQ(input->value + 1, incremented->value);
}
668 
namespace { // Anonymous namespace

// Returns true when both descriptors can be fstat()ed and report the same device and inode
// numbers. Returns false if either fstat() call fails.
bool fdsAreEquivalent(int a, int b) {
    struct stat first {};
    struct stat second {};
    // Short-circuit evaluation keeps the original ordering: b is only queried if a succeeded.
    const bool bothReadable = (fstat(a, &first) == 0) && (fstat(b, &second) == 0);
    if (!bothReadable) {
        return false;
    }
    const bool sameDevice = first.st_dev == second.st_dev;
    const bool sameInode = first.st_ino == second.st_ino;
    return sameDevice && sameInode;
}

} // Anonymous namespace
680 
// Sends a native handle (one fd, one int) through increment() many times. The returned handle
// must refer to the same underlying file and carry the int incremented by one. Looping for more
// iterations than the process may have fds open makes any fd leak surface as a failure.
TEST_F(SafeInterfaceTest, TestIncrementNativeHandle) {
    // Create an fd we can use to send and receive from the remote process
    base::unique_fd eventFd{eventfd(0 /*initval*/, 0 /*flags*/)};
    ASSERT_NE(-1, eventFd.get());

    // Determine the maximum number of fds this process can have open
    struct rlimit limit {};
    ASSERT_EQ(0, getrlimit(RLIMIT_NOFILE, &limit));
    uint32_t maxFds = static_cast<uint32_t>(limit.rlim_cur);

    // Perform this test enough times to rule out fd leaks
    for (uint32_t iter = 0; iter < (2 * maxFds); ++iter) {
        native_handle* handle = native_handle_create(1 /*numFds*/, 1 /*numInts*/);
        ASSERT_NE(nullptr, handle);
        handle->data[0] = dup(eventFd.get());
        handle->data[1] = 1;

        // This cannot fail, as it is just the sp<NativeHandle> taking responsibility for closing
        // the native_handle when it goes out of scope
        sp<NativeHandle> a = NativeHandle::create(handle, true);

        // Check dup() only after ownership has been handed over, so an early return from the
        // assertion does not leak the native_handle allocation.
        ASSERT_NE(-1, a->handle()->data[0]) << "dup() failed on iteration " << iter;

        sp<NativeHandle> aPlusOne;
        status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
        ASSERT_EQ(NO_ERROR, result);
        // Guard the dereferences below against a call that succeeds without filling the output.
        ASSERT_NE(nullptr, aPlusOne.get());
        ASSERT_TRUE(fdsAreEquivalent(a->handle()->data[0], aPlusOne->handle()->data[0]));
        ASSERT_EQ(a->handle()->data[1] + 1, aPlusOne->handle()->data[1]);
    }
}
709 
// Passes a NoCopyNoMove parcelable through increment() and expects its value to come back one
// higher.
TEST_F(SafeInterfaceTest, TestIncrementNoCopyNoMove) {
    const NoCopyNoMove input{1};
    NoCopyNoMove incremented{0};
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input.getValue() + 1, incremented.getValue());
}
717 
// Passes a vector of TestParcelable through increment(); the result must have the same length
// and every element's value must be one higher than the corresponding input element's.
TEST_F(SafeInterfaceTest, TestIncremementParcelableVector) {
    const std::vector<TestParcelable> inputs{TestParcelable{1}, TestParcelable{2}};
    std::vector<TestParcelable> outputs;
    const status_t status = mSafeInterfaceTest->increment(inputs, &outputs);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(inputs.size(), outputs.size());

    // Sizes match (asserted above), so walking both vectors in lockstep stays in bounds.
    auto out = outputs.begin();
    for (const TestParcelable& in : inputs) {
        ASSERT_EQ(in.getValue() + 1, out->getValue());
        ++out;
    }
}
728 
// Sends "asdf" through doubleString() and expects the string repeated twice in the output.
TEST_F(SafeInterfaceTest, TestDoubleString) {
    const String8 input{"asdf"};
    const String8 expected{"asdfasdf"};
    String8 doubled;
    const status_t status = mSafeInterfaceTest->doubleString(input, &doubled);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_TRUE(doubled == expected);
}
736 
// Hands a local callback object to callMeBack() and expects it to be invoked with the input
// value plus one within 100ms.
TEST_F(SafeInterfaceTest, TestCallMeBack) {
    // Callback implementation that stores the delivered value and wakes up the waiting test.
    class CallbackReceiver : public BnCallback {
    public:
        void onCallback(int32_t aPlusOne) override {
            ALOG(LOG_INFO, "CallbackReceiver", "%s", __PRETTY_FUNCTION__);
            std::lock_guard<std::mutex> guard(mMutex);
            mValue = aPlusOne;
            mCondition.notify_one();
        }

        // Blocks for up to 100ms until onCallback() has stored a value. Returns that value, or
        // std::nullopt if the wait timed out first.
        std::optional<int32_t> waitForCallback() {
            std::unique_lock<std::mutex> guard(mMutex);
            const auto valueArrived = [this]() { return mValue.has_value(); };
            if (!mCondition.wait_for(guard, 100ms, valueArrived)) {
                return std::nullopt;
            }
            return mValue;
        }

    private:
        std::mutex mMutex;
        std::condition_variable mCondition;
        std::optional<int32_t> mValue;
    };

    const int32_t input = 1;
    sp<CallbackReceiver> receiver = new CallbackReceiver;
    mSafeInterfaceTest->callMeBack(receiver, input);

    const auto delivered = receiver->waitForCallback();
    ASSERT_TRUE(delivered.has_value());
    ASSERT_EQ(input + 1, *delivered);
}
767 
// increment() on an int32_t must succeed and yield input + 1.
TEST_F(SafeInterfaceTest, TestIncrementInt32) {
    const int32_t input = 1;
    int32_t incremented = 0;
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input + 1, incremented);
}
775 
// increment() on a uint32_t must succeed and yield input + 1.
TEST_F(SafeInterfaceTest, TestIncrementUint32) {
    const uint32_t input = 1;
    uint32_t incremented = 0;
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input + 1, incremented);
}
783 
// increment() on an int64_t must succeed and yield input + 1.
TEST_F(SafeInterfaceTest, TestIncrementInt64) {
    const int64_t input = 1;
    int64_t incremented = 0;
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input + 1, incremented);
}
791 
// increment() on a uint64_t must succeed and yield input + 1.
TEST_F(SafeInterfaceTest, TestIncrementUint64) {
    const uint64_t input = 1;
    uint64_t incremented = 0;
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input + 1, incremented);
}
799 
// increment() on a float must succeed and yield input + 1.0f. The comparison is exact; the
// values used here (1.0f and 2.0f) are exactly representable.
TEST_F(SafeInterfaceTest, TestIncrementFloat) {
    const float input = 1.0f;
    float incremented = 0.0f;
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input + 1.0f, incremented);
}
807 
// Exercises the four-argument increment() overload: two independent int32_t inputs, each with
// its own output, must both come back incremented by one.
TEST_F(SafeInterfaceTest, TestIncrementTwo) {
    const int32_t a = 1;
    int32_t aPlusOne = 0;
    const int32_t b = 2;
    int32_t bPlusOne = 0;
    // Pass the named inputs (not duplicated literals) so the values sent and the values the
    // assertions below are derived from cannot drift apart.
    status_t result = mSafeInterfaceTest->increment(a, &aPlusOne, b, &bPlusOne);
    ASSERT_EQ(NO_ERROR, result);
    ASSERT_EQ(a + 1, aPlusOne);
    ASSERT_EQ(b + 1, bPlusOne);
}
818 
main(int argc,char ** argv)819 extern "C" int main(int argc, char **argv) {
820     testing::InitGoogleTest(&argc, argv);
821 
822     if (fork() == 0) {
823         prctl(PR_SET_PDEATHSIG, SIGHUP);
824         sp<BnSafeInterfaceTest> nativeService = new BnSafeInterfaceTest;
825         status_t status = defaultServiceManager()->addService(kServiceName, nativeService);
826         if (status != OK) {
827             ALOG(LOG_INFO, "SafeInterfaceServer", "could not register");
828             return EXIT_FAILURE;
829         }
830         IPCThreadState::self()->joinThreadPool();
831         return EXIT_FAILURE;
832     }
833 
834     return RUN_ALL_TESTS();
835 }
836 
837 } // namespace tests
838 } // namespace android
839