1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.protolog.tool
18 
19 import com.github.javaparser.StaticJavaParser
20 import com.github.javaparser.ast.CompilationUnit
21 import com.github.javaparser.ast.expr.MethodCallExpr
22 import com.github.javaparser.ast.stmt.IfStmt
23 import org.junit.Assert.assertEquals
24 import org.junit.Assert.assertFalse
25 import org.junit.Test
26 import org.mockito.Mockito
27 
/**
 * Tests for [SourceTransformer.processClass].
 *
 * Every test feeds a small Java source through the transformer while a mocked
 * [ProtoLogCallProcessor] reports the `ProtoLog.w` calls found in that source, and then checks
 * both the structure and the exact text of the rewritten source.
 */
class SourceTransformerTest {
    companion object {
        /** Fully qualified name the rewritten log calls are expected to be scoped to. */
        private const val PROTO_LOG_IMPL_PATH = "org.example.ProtoLogImpl"

        /** Fully qualified name of the class holding the `<GROUP>_enabled` guard flags. */
        private const val PROTO_LOG_CACHE_PATH = "org.example.ProtoLogCache"

        /* ktlint-disable max-line-length */
        /** Input: a single `ProtoLog.w` call with two parameters. */
        private val TEST_CODE = """
            package org.example;

            class Test {
                void test() {
                    ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1);
                }
            }
            """.trimIndent()

        /** Input: a single `ProtoLog.w` call whose message and arguments span three lines. */
        private val TEST_CODE_MULTILINE = """
            package org.example;

            class Test {
                void test() {
                    ProtoLog.w(TEST_GROUP, "test %d %f " +
                    "abc %s\n test", 100,
                     0.1, "test");
                }
            }
            """.trimIndent()

        /** Input: three live `ProtoLog.w` calls plus one that is commented out. */
        private val TEST_CODE_MULTICALLS = """
            package org.example;

            class Test {
                void test() {
                    ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1); /* ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1); */ ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1);
                    ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1);
                }
            }
            """.trimIndent()

        /** Input: a single `ProtoLog.w` call that has a message but no parameters. */
        private val TEST_CODE_NO_PARAMS = """
            package org.example;

            class Test {
                void test() {
                    ProtoLog.w(TEST_GROUP, "test");
                }
            }
            """.trimIndent()

        /** Expected rewrite of [TEST_CODE] when the message text is kept. */
        private val TRANSFORMED_CODE_TEXT_ENABLED = """
            package org.example;

            class Test {
                void test() {
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; org.example.ProtoLogImpl.w(TEST_GROUP, 1698911065, 9, "test %d %f", protoLogParam0, protoLogParam1); }
                }
            }
            """.trimIndent()

        /** Expected rewrite of [TEST_CODE_MULTILINE] when the message text is kept. */
        private val TRANSFORMED_CODE_MULTILINE_TEXT_ENABLED = """
            package org.example;

            class Test {
                void test() {
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; String protoLogParam2 = String.valueOf("test"); org.example.ProtoLogImpl.w(TEST_GROUP, 1780316587, 9, "test %d %f " + "abc %s\n test", protoLogParam0, protoLogParam1, protoLogParam2);

            }
                }
            }
            """.trimIndent()

        /** Expected rewrite of [TEST_CODE_MULTICALLS]; the commented-out call stays untouched. */
        private val TRANSFORMED_CODE_MULTICALL_TEXT_ENABLED = """
            package org.example;

            class Test {
                void test() {
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; org.example.ProtoLogImpl.w(TEST_GROUP, 1698911065, 9, "test %d %f", protoLogParam0, protoLogParam1); } /* ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1); */ if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; org.example.ProtoLogImpl.w(TEST_GROUP, 1698911065, 9, "test %d %f", protoLogParam0, protoLogParam1); }
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; org.example.ProtoLogImpl.w(TEST_GROUP, 1698911065, 9, "test %d %f", protoLogParam0, protoLogParam1); }
                }
            }
            """.trimIndent()

        /** Expected rewrite of [TEST_CODE_NO_PARAMS]. */
        private val TRANSFORMED_CODE_NO_PARAMS = """
            package org.example;

            class Test {
                void test() {
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { org.example.ProtoLogImpl.w(TEST_GROUP, -1741986185, 0, "test", (Object[]) null); }
                }
            }
            """.trimIndent()

        /** Expected rewrite of [TEST_CODE] when the message text is replaced by `null`. */
        private val TRANSFORMED_CODE_TEXT_DISABLED = """
            package org.example;

            class Test {
                void test() {
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; org.example.ProtoLogImpl.w(TEST_GROUP, 1698911065, 9, null, protoLogParam0, protoLogParam1); }
                }
            }
            """.trimIndent()

        /** Expected rewrite of [TEST_CODE_MULTILINE] when the message text is replaced by `null`. */
        private val TRANSFORMED_CODE_MULTILINE_TEXT_DISABLED = """
            package org.example;

            class Test {
                void test() {
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; String protoLogParam2 = String.valueOf("test"); org.example.ProtoLogImpl.w(TEST_GROUP, 1780316587, 9, null, protoLogParam0, protoLogParam1, protoLogParam2);

            }
                }
            }
            """.trimIndent()

        /** Expected rewrite of [TEST_CODE] when the call is wrapped in `if (false)`. */
        private val TRANSFORMED_CODE_DISABLED = """
            package org.example;

            class Test {
                void test() {
                    if (false) { /* TEST_GROUP is disabled */ ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1); }
                }
            }
            """.trimIndent()

        /** Expected rewrite of [TEST_CODE_MULTILINE] when the call is wrapped in `if (false)`. */
        private val TRANSFORMED_CODE_MULTILINE_DISABLED = """
            package org.example;

            class Test {
                void test() {
                    if (false) { /* TEST_GROUP is disabled */ ProtoLog.w(TEST_GROUP, "test %d %f " + "abc %s\n test", 100, 0.1, "test");

            }
                }
            }
            """.trimIndent()
        /* ktlint-enable max-line-length */

        /** Value passed for both path arguments of `processClass`. */
        private const val PATH = "com.example.Test.java"
    }

    private val processor: ProtoLogCallProcessor = Mockito.mock(ProtoLogCallProcessor::class.java)
    private val transformer =
            SourceTransformer(PROTO_LOG_IMPL_PATH, PROTO_LOG_CACHE_PATH, processor)

    /**
     * Wraps [Mockito.any] with a plain `T` return type so the matcher can be passed straight
     * into the stubbed [ProtoLogCallProcessor.process] call below.
     */
    private fun <T> any(type: Class<T>): T = Mockito.any<T>(type)

    /**
     * Runs [transformer] over [source] and returns the rewritten text.
     *
     * [processor] is stubbed so that, when asked to process a compilation unit, it reports the
     * first [callCount] method calls of the parsed [source] to the supplied visitor as
     * [LogLevel.WARN] calls with the given [message] and [group], and then returns the
     * compilation unit it was handed.
     */
    private fun transform(
        source: String,
        message: String,
        group: LogGroup,
        callCount: Int = 1
    ): String {
        val unit = StaticJavaParser.parse(source)

        Mockito.`when`(processor.process(any(CompilationUnit::class.java),
                any(ProtoLogCallVisitor::class.java), any(String::class.java)))
                .thenAnswer { invocation ->
            val visitor = invocation.arguments[1] as ProtoLogCallVisitor

            // Looked up lazily, i.e. only once the transformer actually invokes the processor.
            val calls = unit.findAll(MethodCallExpr::class.java)
            for (index in 0 until callCount) {
                visitor.processCall(calls[index], message, LogLevel.WARN, group)
            }

            invocation.arguments[0] as CompilationUnit
        }

        return transformer.processClass(source, PATH, PATH, unit)
    }

    /**
     * Parses [transformed] as Java (failing the test if it does not parse) and returns its
     * `if` statements after asserting that there are exactly [expectedCount] of them.
     */
    private fun ifStatementsIn(transformed: String, expectedCount: Int): List<IfStmt> {
        val ifStmts = StaticJavaParser.parse(transformed).findAll(IfStmt::class.java)
        assertEquals(expectedCount, ifStmts.size)
        return ifStmts
    }

    /**
     * Asserts the parts of a rewritten call that are shared by all "group enabled" tests and
     * returns the method call found at [callIndex] inside the `then` branch of [ifStmt].
     *
     * Checked: the guard condition is the `TEST_GROUP_enabled` flag on [PROTO_LOG_CACHE_PATH],
     * there is no `else` branch, the `then` branch has [childCount] direct child nodes, and the
     * call is `w` scoped to [PROTO_LOG_IMPL_PATH] with [argumentCount] arguments, the first of
     * which is `TEST_GROUP`.
     */
    private fun assertGuardedImplCall(
        ifStmt: IfStmt,
        childCount: Int,
        callIndex: Int,
        argumentCount: Int
    ): MethodCallExpr {
        assertEquals("$PROTO_LOG_CACHE_PATH.TEST_GROUP_enabled", ifStmt.condition.toString())
        assertFalse(ifStmt.elseStmt.isPresent)
        assertEquals(childCount, ifStmt.thenStmt.childNodes.size)
        val methodCall = ifStmt.thenStmt.findAll(MethodCallExpr::class.java)[callIndex]
        assertEquals(PROTO_LOG_IMPL_PATH, methodCall.scope.get().toString())
        assertEquals("w", methodCall.name.asString())
        assertEquals(argumentCount, methodCall.arguments.size)
        assertEquals("TEST_GROUP", methodCall.arguments[0].toString())
        return methodCall
    }

    @Test
    fun processClass_textEnabled() {
        val out = transform(TEST_CODE, "test %d %f",
                LogGroup("TEST_GROUP", true, true, "WM_TEST"))

        val ifStmt = ifStatementsIn(out, expectedCount = 1)[0]
        val methodCall = assertGuardedImplCall(ifStmt, childCount = 3, callIndex = 0,
                argumentCount = 6)
        assertEquals("1698911065", methodCall.arguments[1].toString())
        assertEquals(0b1001.toString(), methodCall.arguments[2].toString())
        assertEquals("\"test %d %f\"", methodCall.arguments[3].toString())
        assertEquals("protoLogParam0", methodCall.arguments[4].toString())
        assertEquals("protoLogParam1", methodCall.arguments[5].toString())
        assertEquals(TRANSFORMED_CODE_TEXT_ENABLED, out)
    }

    @Test
    fun processClass_textEnabledMulticalls() {
        val out = transform(TEST_CODE_MULTICALLS, "test %d %f",
                LogGroup("TEST_GROUP", true, true, "WM_TEST"), callCount = 3)

        // Only the middle of the three generated guards is inspected structurally; the final
        // full-text comparison covers the other two.
        val ifStmt = ifStatementsIn(out, expectedCount = 3)[1]
        val methodCall = assertGuardedImplCall(ifStmt, childCount = 3, callIndex = 0,
                argumentCount = 6)
        assertEquals("1698911065", methodCall.arguments[1].toString())
        assertEquals(0b1001.toString(), methodCall.arguments[2].toString())
        assertEquals("\"test %d %f\"", methodCall.arguments[3].toString())
        assertEquals("protoLogParam0", methodCall.arguments[4].toString())
        assertEquals("protoLogParam1", methodCall.arguments[5].toString())
        assertEquals(TRANSFORMED_CODE_MULTICALL_TEXT_ENABLED, out)
    }

    @Test
    fun processClass_textEnabledMultiline() {
        val out = transform(TEST_CODE_MULTILINE, "test %d %f abc %s\n test",
                LogGroup("TEST_GROUP", true, true, "WM_TEST"))

        val ifStmt = ifStatementsIn(out, expectedCount = 1)[0]
        // Index 1: the first method call in the block is String.valueOf("test").
        val methodCall = assertGuardedImplCall(ifStmt, childCount = 4, callIndex = 1,
                argumentCount = 7)
        assertEquals("1780316587", methodCall.arguments[1].toString())
        assertEquals(0b001001.toString(), methodCall.arguments[2].toString())
        assertEquals("protoLogParam0", methodCall.arguments[4].toString())
        assertEquals("protoLogParam1", methodCall.arguments[5].toString())
        assertEquals("protoLogParam2", methodCall.arguments[6].toString())
        assertEquals(TRANSFORMED_CODE_MULTILINE_TEXT_ENABLED, out)
    }

    @Test
    fun processClass_noParams() {
        val out = transform(TEST_CODE_NO_PARAMS, "test",
                LogGroup("TEST_GROUP", true, true, "WM_TEST"))

        val ifStmt = ifStatementsIn(out, expectedCount = 1)[0]
        val methodCall = assertGuardedImplCall(ifStmt, childCount = 1, callIndex = 0,
                argumentCount = 5)
        assertEquals("-1741986185", methodCall.arguments[1].toString())
        assertEquals(0.toString(), methodCall.arguments[2].toString())
        assertEquals(TRANSFORMED_CODE_NO_PARAMS, out)
    }

    @Test
    fun processClass_textDisabled() {
        val out = transform(TEST_CODE, "test %d %f",
                LogGroup("TEST_GROUP", true, false, "WM_TEST"))

        val ifStmt = ifStatementsIn(out, expectedCount = 1)[0]
        val methodCall = assertGuardedImplCall(ifStmt, childCount = 3, callIndex = 0,
                argumentCount = 6)
        assertEquals("1698911065", methodCall.arguments[1].toString())
        assertEquals(0b1001.toString(), methodCall.arguments[2].toString())
        assertEquals("null", methodCall.arguments[3].toString())
        assertEquals("protoLogParam0", methodCall.arguments[4].toString())
        assertEquals("protoLogParam1", methodCall.arguments[5].toString())
        assertEquals(TRANSFORMED_CODE_TEXT_DISABLED, out)
    }

    @Test
    fun processClass_textDisabledMultiline() {
        val out = transform(TEST_CODE_MULTILINE, "test %d %f abc %s\n test",
                LogGroup("TEST_GROUP", true, false, "WM_TEST"))

        val ifStmt = ifStatementsIn(out, expectedCount = 1)[0]
        // Index 1: the first method call in the block is String.valueOf("test").
        val methodCall = assertGuardedImplCall(ifStmt, childCount = 4, callIndex = 1,
                argumentCount = 7)
        assertEquals("1780316587", methodCall.arguments[1].toString())
        assertEquals(0b001001.toString(), methodCall.arguments[2].toString())
        assertEquals("null", methodCall.arguments[3].toString())
        assertEquals("protoLogParam0", methodCall.arguments[4].toString())
        assertEquals("protoLogParam1", methodCall.arguments[5].toString())
        assertEquals("protoLogParam2", methodCall.arguments[6].toString())
        assertEquals(TRANSFORMED_CODE_MULTILINE_TEXT_DISABLED, out)
    }

    @Test
    fun processClass_disabled() {
        val out = transform(TEST_CODE, "test %d %f",
                LogGroup("TEST_GROUP", false, true, "WM_TEST"))

        val ifStmt = ifStatementsIn(out, expectedCount = 1)[0]
        assertEquals("false", ifStmt.condition.toString())
        assertEquals(TRANSFORMED_CODE_DISABLED, out)
    }

    @Test
    fun processClass_disabledMultiline() {
        val out = transform(TEST_CODE_MULTILINE, "test %d %f abc %s\n test",
                LogGroup("TEST_GROUP", false, true, "WM_TEST"))

        val ifStmt = ifStatementsIn(out, expectedCount = 1)[0]
        assertEquals("false", ifStmt.condition.toString())
        assertEquals(TRANSFORMED_CODE_MULTILINE_DISABLED, out)
    }
}
452