/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.protolog.tool

import com.github.javaparser.StaticJavaParser
import com.github.javaparser.ast.CompilationUnit
import com.github.javaparser.ast.expr.MethodCallExpr
import com.github.javaparser.ast.stmt.IfStmt
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Test
import org.mockito.Mockito

/**
 * Tests for [SourceTransformer.processClass].
 *
 * Each test parses a small Java fixture and stubs [ProtoLogCallProcessor] so that it reports the
 * fixture's `ProtoLog.w(...)` calls to the [ProtoLogCallVisitor] it is handed. The rewritten
 * source is then checked both structurally (re-parsed with JavaParser) and byte-for-byte against
 * an expected fixture.
 */
class SourceTransformerTest {
    companion object {
        private const val PROTO_LOG_IMPL_PATH = "org.example.ProtoLogImpl"
        private const val PROTO_LOG_CACHE_PATH = "org.example.ProtoLogCache"
        private const val PATH = "com.example.Test.java"

        /** Format string of every call in the single-line fixtures. */
        private const val MESSAGE = "test %d %f"

        /** Format string of the multi-line fixture's call: its two concatenated literals joined. */
        private const val MULTILINE_MESSAGE = "test %d %f abc %s\n test"

        /*
         * Input fixtures and expected transformer output.
         *
         * Whitespace is significant: every expected fixture is compared byte-for-byte with the
         * output of processClass(). For a call spanning several lines, the generated guard keeps
         * the statement's original line count: one line of code, blank padding, then the
         * closing brace at column 0. `${" "}` spells out the space left after the guard's last
         * statement on its first line, so editors that strip trailing whitespace cannot
         * silently break the fixture.
         * NOTE(review): whitespace here was reconstructed; if a multi-line fixture fails,
         * compare it with the raw transformer output first.
         */
        /* ktlint-disable max-line-length */
        private val TEST_CODE = """
            package org.example;

            class Test {
                void test() {
                    ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1);
                }
            }
            """.trimIndent()

        private val TEST_CODE_MULTILINE = """
            package org.example;

            class Test {
                void test() {
                    ProtoLog.w(TEST_GROUP, "test %d %f " +
                            "abc %s\n test", 100,
                            0.1, "test");
                }
            }
            """.trimIndent()

        private val TEST_CODE_MULTICALLS = """
            package org.example;

            class Test {
                void test() {
                    ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1); /* ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1); */ ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1);
                    ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1);
                }
            }
            """.trimIndent()

        private val TEST_CODE_NO_PARAMS = """
            package org.example;

            class Test {
                void test() {
                    ProtoLog.w(TEST_GROUP, "test");
                }
            }
            """.trimIndent()

        private val TRANSFORMED_CODE_TEXT_ENABLED = """
            package org.example;

            class Test {
                void test() {
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; org.example.ProtoLogImpl.w(TEST_GROUP, 1698911065, 9, "test %d %f", protoLogParam0, protoLogParam1); }
                }
            }
            """.trimIndent()

        private val TRANSFORMED_CODE_MULTILINE_TEXT_ENABLED = """
            package org.example;

            class Test {
                void test() {
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; String protoLogParam2 = String.valueOf("test"); org.example.ProtoLogImpl.w(TEST_GROUP, 1780316587, 9, "test %d %f " + "abc %s\n test", protoLogParam0, protoLogParam1, protoLogParam2);${" "}

            }
                }
            }
            """.trimIndent()

        private val TRANSFORMED_CODE_MULTICALL_TEXT_ENABLED = """
            package org.example;

            class Test {
                void test() {
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; org.example.ProtoLogImpl.w(TEST_GROUP, 1698911065, 9, "test %d %f", protoLogParam0, protoLogParam1); } /* ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1); */ if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; org.example.ProtoLogImpl.w(TEST_GROUP, 1698911065, 9, "test %d %f", protoLogParam0, protoLogParam1); }
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; org.example.ProtoLogImpl.w(TEST_GROUP, 1698911065, 9, "test %d %f", protoLogParam0, protoLogParam1); }
                }
            }
            """.trimIndent()

        private val TRANSFORMED_CODE_NO_PARAMS = """
            package org.example;

            class Test {
                void test() {
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { org.example.ProtoLogImpl.w(TEST_GROUP, -1741986185, 0, "test", (Object[]) null); }
                }
            }
            """.trimIndent()

        private val TRANSFORMED_CODE_TEXT_DISABLED = """
            package org.example;

            class Test {
                void test() {
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; org.example.ProtoLogImpl.w(TEST_GROUP, 1698911065, 9, null, protoLogParam0, protoLogParam1); }
                }
            }
            """.trimIndent()

        private val TRANSFORMED_CODE_MULTILINE_TEXT_DISABLED = """
            package org.example;

            class Test {
                void test() {
                    if (org.example.ProtoLogCache.TEST_GROUP_enabled) { long protoLogParam0 = 100; double protoLogParam1 = 0.1; String protoLogParam2 = String.valueOf("test"); org.example.ProtoLogImpl.w(TEST_GROUP, 1780316587, 9, null, protoLogParam0, protoLogParam1, protoLogParam2);${" "}

            }
                }
            }
            """.trimIndent()

        private val TRANSFORMED_CODE_DISABLED = """
            package org.example;

            class Test {
                void test() {
                    if (false) { /* TEST_GROUP is disabled */ ProtoLog.w(TEST_GROUP, "test %d %f", 100, 0.1); }
                }
            }
            """.trimIndent()

        private val TRANSFORMED_CODE_MULTILINE_DISABLED = """
            package org.example;

            class Test {
                void test() {
                    if (false) { /* TEST_GROUP is disabled */ ProtoLog.w(TEST_GROUP, "test %d %f " + "abc %s\n test", 100, 0.1, "test");${" "}

            }
                }
            }
            """.trimIndent()
        /* ktlint-enable max-line-length */
    }

    private val processor: ProtoLogCallProcessor = Mockito.mock(ProtoLogCallProcessor::class.java)
    private val transformer =
            SourceTransformer(PROTO_LOG_IMPL_PATH, PROTO_LOG_CACHE_PATH, processor)

    /**
     * Wraps [Mockito.any] (which returns `null`) so the matcher can be passed where Kotlin
     * expects a `T`.
     */
    private fun <T> any(type: Class<T>): T = Mockito.any<T>(type)

    /** Builds the `TEST_GROUP` log group used by every fixture, with the given flags. */
    private fun testGroup(enabled: Boolean = true, textEnabled: Boolean = true) =
            LogGroup("TEST_GROUP", enabled, textEnabled, "WM_TEST")

    /**
     * Stubs [processor] to report calls to the visitor it receives. The i-th method call found in
     * [input] is reported with the i-th entry of [messages], at [LogLevel.WARN] in [group]. The
     * stub returns the compilation unit it was given.
     */
    private fun stubProcessor(input: CompilationUnit, group: LogGroup, vararg messages: String) {
        Mockito.`when`(processor.process(any(CompilationUnit::class.java),
                any(ProtoLogCallVisitor::class.java), any(String::class.java)))
                .thenAnswer { invocation ->
                    val visitor = invocation.arguments[1] as ProtoLogCallVisitor
                    val calls = input.findAll(MethodCallExpr::class.java)
                    messages.forEachIndexed { idx, message ->
                        visitor.processCall(calls[idx], message, LogLevel.WARN, group)
                    }
                    invocation.arguments[0] as CompilationUnit
                }
    }

    /**
     * Parses [source], stubs [processor] as described by [stubProcessor], and returns the text
     * produced by [SourceTransformer.processClass].
     */
    private fun transform(source: String, group: LogGroup, vararg messages: String): String {
        val input = StaticJavaParser.parse(source)
        stubProcessor(input, group, *messages)
        return transformer.processClass(source, PATH, PATH, input)
    }

    /** Re-parses [out], asserts it has exactly [expectedCount] if-statements, and returns them. */
    private fun parseIfStmts(out: String, expectedCount: Int): List<IfStmt> {
        val ifStmts = StaticJavaParser.parse(out).findAll(IfStmt::class.java)
        assertEquals(expectedCount, ifStmts.size)
        return ifStmts
    }

    /**
     * Asserts that [ifStmt] is a guard for an enabled group: its condition reads the cache flag,
     * it has no else branch, and its then-block has [statementCount] child nodes.
     * Also asserts that the [callIndex]-th method call inside the block is `w` on
     * [PROTO_LOG_IMPL_PATH], and returns that call.
     */
    private fun assertEnabledGuard(
        ifStmt: IfStmt,
        statementCount: Int,
        callIndex: Int = 0
    ): MethodCallExpr {
        assertEquals("$PROTO_LOG_CACHE_PATH.TEST_GROUP_enabled", ifStmt.condition.toString())
        assertFalse(ifStmt.elseStmt.isPresent)
        assertEquals(statementCount, ifStmt.thenStmt.childNodes.size)
        val methodCall = ifStmt.thenStmt.findAll(MethodCallExpr::class.java)[callIndex]
        assertEquals(PROTO_LOG_IMPL_PATH, methodCall.scope.get().toString())
        assertEquals("w", methodCall.name.asString())
        return methodCall
    }

    /** Source text of every argument of [call], in order. */
    private fun argumentsOf(call: MethodCallExpr): List<String> =
            call.arguments.map { it.toString() }

    @Test
    fun processClass_textEnabled() {
        val out = transform(TEST_CODE, testGroup(), MESSAGE)

        val methodCall = assertEnabledGuard(parseIfStmts(out, expectedCount = 1)[0],
                statementCount = 3)
        assertEquals(listOf("TEST_GROUP", "1698911065", 0b1001.toString(), "\"test %d %f\"",
                "protoLogParam0", "protoLogParam1"), argumentsOf(methodCall))
        assertEquals(TRANSFORMED_CODE_TEXT_ENABLED, out)
    }

    @Test
    fun processClass_textEnabledMulticalls() {
        val out = transform(TEST_CODE_MULTICALLS, testGroup(), MESSAGE, MESSAGE, MESSAGE)

        // Check the middle guard: the one that follows the commented-out call.
        val methodCall = assertEnabledGuard(parseIfStmts(out, expectedCount = 3)[1],
                statementCount = 3)
        assertEquals(listOf("TEST_GROUP", "1698911065", 0b1001.toString(), "\"test %d %f\"",
                "protoLogParam0", "protoLogParam1"), argumentsOf(methodCall))
        assertEquals(TRANSFORMED_CODE_MULTICALL_TEXT_ENABLED, out)
    }

    @Test
    fun processClass_textEnabledMultiline() {
        val out = transform(TEST_CODE_MULTILINE, testGroup(), MULTILINE_MESSAGE)

        // The first call in the guard is String.valueOf("test"); the ProtoLogImpl call is second.
        val methodCall = assertEnabledGuard(parseIfStmts(out, expectedCount = 1)[0],
                statementCount = 4, callIndex = 1)
        assertEquals(listOf("TEST_GROUP", "1780316587", 0b1001.toString(),
                "\"test %d %f \" + \"abc %s\\n test\"", "protoLogParam0", "protoLogParam1",
                "protoLogParam2"), argumentsOf(methodCall))
        assertEquals(TRANSFORMED_CODE_MULTILINE_TEXT_ENABLED, out)
    }

    @Test
    fun processClass_noParams() {
        val out = transform(TEST_CODE_NO_PARAMS, testGroup(), "test")

        val methodCall = assertEnabledGuard(parseIfStmts(out, expectedCount = 1)[0],
                statementCount = 1)
        // With no format arguments, an explicit null Object[] is passed as the varargs argument.
        assertEquals(listOf("TEST_GROUP", "-1741986185", "0", "\"test\"", "(Object[]) null"),
                argumentsOf(methodCall))
        assertEquals(TRANSFORMED_CODE_NO_PARAMS, out)
    }

    @Test
    fun processClass_textDisabled() {
        val out = transform(TEST_CODE, testGroup(textEnabled = false), MESSAGE)

        val methodCall = assertEnabledGuard(parseIfStmts(out, expectedCount = 1)[0],
                statementCount = 3)
        // With text disabled the message literal is replaced by null.
        assertEquals(listOf("TEST_GROUP", "1698911065", 0b1001.toString(), "null",
                "protoLogParam0", "protoLogParam1"), argumentsOf(methodCall))
        assertEquals(TRANSFORMED_CODE_TEXT_DISABLED, out)
    }

    @Test
    fun processClass_textDisabledMultiline() {
        val out = transform(TEST_CODE_MULTILINE, testGroup(textEnabled = false),
                MULTILINE_MESSAGE)

        val methodCall = assertEnabledGuard(parseIfStmts(out, expectedCount = 1)[0],
                statementCount = 4, callIndex = 1)
        assertEquals(listOf("TEST_GROUP", "1780316587", 0b1001.toString(), "null",
                "protoLogParam0", "protoLogParam1", "protoLogParam2"), argumentsOf(methodCall))
        assertEquals(TRANSFORMED_CODE_MULTILINE_TEXT_DISABLED, out)
    }

    @Test
    fun processClass_disabled() {
        val out = transform(TEST_CODE, testGroup(enabled = false), MESSAGE)

        assertEquals("false", parseIfStmts(out, expectedCount = 1)[0].condition.toString())
        assertEquals(TRANSFORMED_CODE_DISABLED, out)
    }

    @Test
    fun processClass_disabledMultiline() {
        val out = transform(TEST_CODE_MULTILINE, testGroup(enabled = false), MULTILINE_MESSAGE)

        assertEquals("false", parseIfStmts(out, expectedCount = 1)[0].condition.toString())
        assertEquals(TRANSFORMED_CODE_MULTILINE_DISABLED, out)
    }
}