1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 import com.squareup.javapoet.ClassName
18 import com.squareup.javapoet.FieldSpec
19 import com.squareup.javapoet.JavaFile
20 import com.squareup.javapoet.MethodSpec
21 import com.squareup.javapoet.NameAllocator
22 import com.squareup.javapoet.ParameterSpec
23 import com.squareup.javapoet.TypeSpec
24 import java.io.File
25 import java.io.FileInputStream
26 import java.io.FileNotFoundException
27 import java.io.FileOutputStream
28 import java.io.IOException
29 import java.nio.charset.StandardCharsets
30 import java.time.Year
31 import java.util.Objects
32 import javax.lang.model.element.Modifier
33 
// JavaPoet can only emit line comments and cannot put a blank line after a file-level comment,
// so the block-comment license header (followed by one blank line) is assembled by hand here.
// NOTE(review): the code that writes this ahead of the generated JavaFile lives outside this view.
val FILE_HEADER = listOf(
    "/*",
    " * Copyright (C) ${Year.now().value} The Android Open Source Project",
    " *",
    " * Licensed under the Apache License, Version 2.0 (the \"License\");",
    " * you may not use this file except in compliance with the License.",
    " * You may obtain a copy of the License at",
    " *",
    " *      http://www.apache.org/licenses/LICENSE-2.0",
    " *",
    " * Unless required by applicable law or agreed to in writing, software",
    " * distributed under the License is distributed on an \"AS IS\" BASIS,",
    " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
    " * See the License for the specific language governing permissions and",
    " * limitations under the License.",
    " */",
    "",
    "// Generated by xmlpersistence. DO NOT MODIFY!",
    "// CHECKSTYLE:OFF Generated code",
    "// @formatter:off"
).joinToString(separator = "\n", postfix = "\n\n")
56 
/** JavaPoet handle for `android.util.AtomicFile`, the type of the generated `mFile` field. */
private val atomicFileType: ClassName = ClassName.get("android.util", "AtomicFile")
58 
/**
 * Generates the Java source of the XML persistence class described by [persistence].
 *
 * The result is a public final class named after [PersistenceInfo.name], placed in the root type's
 * package. It exposes a constructor taking a `File`, plus `read()`, `write(...)` and `delete()`.
 * It also has one private parse method and one private serialize method for each distinct class
 * type reachable from the root.
 */
fun generate(persistence: PersistenceInfo): JavaFile {
    val root = persistence.root
    // Several fields may share one class type; its parse/serialize helpers are emitted only once.
    val uniqueClassFields = root.allClassFields.distinctBy { it.type }
    val classJavadoc = """
        Generated class implementing XML persistence for${'$'}W{@link $1T}.
        <p>
        This class provides atomicity for persistence via {@link $2T}, however it does not provide
        thread safety, so please bring your own synchronization mechanism.
    """.trimIndent()
    val typeSpec = TypeSpec.classBuilder(persistence.name).apply {
        addJavadoc(classJavadoc, root.type, atomicFileType)
        addModifiers(Modifier.PUBLIC, Modifier.FINAL)
        addField(generateFileField())
        addMethod(generateConstructor())
        addMethod(generateReadMethod(root))
        addMethod(generateParseMethod(root))
        uniqueClassFields.forEach { addMethod(generateParseClassMethod(it)) }
        addMethod(generateWriteMethod(root))
        addMethod(generateSerializeMethod(root))
        uniqueClassFields.forEach { addMethod(generateSerializeClassMethod(it)) }
        addMethod(generateDeleteMethod())
    }.build()
    return JavaFile.builder(root.type.packageName(), typeSpec)
        .skipJavaLangImports(true)
        .indent("    ")
        .build()
}
86 
/** `android.annotation.NonNull`, applied to generated fields, parameters and parse methods. */
private val nonNullType: ClassName = ClassName.get("android.annotation", "NonNull")
88 
/**
 * Generates the `private final @NonNull AtomicFile mFile` field.
 *
 * The generated `read`, `write` and `delete` methods all go through this field.
 */
private fun generateFileField(): FieldSpec {
    val field = FieldSpec.builder(atomicFileType, "mFile", Modifier.PRIVATE, Modifier.FINAL)
    field.addAnnotation(nonNullType)
    return field.build()
}
93 
/** Generates the public constructor, which wraps the given `File` in a new `AtomicFile`. */
private fun generateConstructor(): MethodSpec {
    val fileParameter = ParameterSpec.builder(File::class.java, "file")
        .addAnnotation(nonNullType)
        .build()
    return MethodSpec.constructorBuilder()
        .addJavadoc("Create an instance of this class.\n\n@param file the XML file for persistence")
        .addModifiers(Modifier.PUBLIC)
        .addParameter(fileParameter)
        .addStatement("mFile = new \$1T(file)", atomicFileType)
        .build()
}
109 
/** `android.annotation.Nullable`, used to annotate generated methods. */
private val nullableType: ClassName = ClassName.get("android.annotation", "Nullable")

/** `org.xmlpull.v1.XmlPullParser`, the parser passed through the generated parse methods. */
private val xmlPullParserType: ClassName = ClassName.get("org.xmlpull.v1", "XmlPullParser")

/** `android.util.Xml`, used by generated code to create the pull parser and the serializer. */
private val xmlType: ClassName = ClassName.get("android.util", "Xml")

/** `org.xmlpull.v1.XmlPullParserException`, declared/caught by generated parsing code. */
private val xmlPullParserExceptionType: ClassName =
    ClassName.get("org.xmlpull.v1", "XmlPullParserException")
117 
/**
 * Generates the public `@Nullable read()` method.
 *
 * The generated code opens the file via `mFile.openRead()` and delegates to `parse(parser)`. It
 * returns `null` on `FileNotFoundException` and wraps any other `IOException` or
 * `XmlPullParserException` in an `IllegalArgumentException`.
 */
private fun generateReadMethod(rootField: ClassFieldInfo): MethodSpec {
    val javadoc = """
        Read${'$'}W{@link $1T}${'$'}Wfrom${'$'}Wthe${'$'}WXML${'$'}Wfile.

        @return the persisted${'$'}W{@link $1T},${'$'}Wor${'$'}W{@code null}${'$'}Wif${'$'}Wthe${'$'}WXML${'$'}Wfile${'$'}Wdoesn't${'$'}Wexist
        @throws IllegalArgumentException if an error occurred while reading
    """.trimIndent()
    val method = MethodSpec.methodBuilder("read")
        .addJavadoc(javadoc, rootField.type)
        .addAnnotation(nullableType)
        .addModifiers(Modifier.PUBLIC)
        .returns(rootField.type)
    method.beginControlFlow(
        "try (\$1T inputStream = mFile.openRead())", FileInputStream::class.java
    )
    method.addStatement("final \$1T parser = \$2T.newPullParser()", xmlPullParserType, xmlType)
    method.addStatement("parser.setInput(inputStream, null)")
    method.addStatement("return parse(parser)")
    // A missing file is the "nothing persisted yet" case, not an error.
    method.nextControlFlow("catch (\$1T e)", FileNotFoundException::class.java)
    method.addStatement("return null")
    method.nextControlFlow(
        "catch (\$1T | \$2T e)", IOException::class.java, xmlPullParserExceptionType
    )
    method.addStatement("throw new IllegalArgumentException(e)")
    method.endControlFlow()
    return method.build()
}
143 
/**
 * This class field followed by every class field nested inside it, either directly or as the
 * element type of a list field.
 *
 * Traversal is depth-first in field order. Duplicates are kept; callers de-duplicate as needed.
 */
private val ClassFieldInfo.allClassFields: List<ClassFieldInfo>
    get() = listOf(this) + fields.flatMap { field ->
        when (field) {
            is ClassFieldInfo -> field.allClassFields
            is ListFieldInfo -> field.element.allClassFields
            else -> emptyList()
        }
    }
155 
/**
 * Generates the private static `parse(parser)` entry point.
 *
 * The generated code scans the elements one level below the parser's current depth. For the first
 * one named like the root tag, it returns the root class's parse method result. If the document
 * ends without such an element, it throws `IllegalArgumentException`.
 */
private fun generateParseMethod(rootField: ClassFieldInfo): MethodSpec {
    val parserParameter = ParameterSpec.builder(xmlPullParserType, "parser")
        .addAnnotation(nonNullType)
        .build()
    val checkedExceptions =
        listOf(ClassName.get(IOException::class.java), xmlPullParserExceptionType)
    val builder = MethodSpec.methodBuilder("parse")
        .addAnnotation(nonNullType)
        .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
        .returns(rootField.type)
        .addParameter(parserParameter)
        .addExceptions(checkedExceptions)

    // Advance until END_DOCUMENT or the END_TAG that closes the starting depth.
    val loopHeader = "while ((type = parser.next()) != \$1T.END_DOCUMENT\$W" +
        "&& ((depth = parser.getDepth()) >= innerDepth || type != \$1T.END_TAG))"
    builder.addStatement("int type")
    builder.addStatement("int depth")
    builder.addStatement("int innerDepth = parser.getDepth() + 1")
    builder.beginControlFlow(loopHeader, xmlPullParserType)
    // Only START_TAGs of direct children are of interest.
    builder.beginControlFlow(
        "if (depth > innerDepth || type != \$1T.START_TAG)", xmlPullParserType
    )
    builder.addStatement("continue")
    builder.endControlFlow()
    builder.beginControlFlow(
        "if (\$1T.equals(parser.getName(),\$W\$2S))", Objects::class.java, rootField.tagName
    )
    builder.addStatement("return \$1L(parser)", rootField.parseMethodName)
    builder.endControlFlow()
    builder.endControlFlow()
    builder.addStatement(
        "throw new IllegalArgumentException(\$1S)", "Missing root tag <${rootField.tagName}>"
    )
    return builder.build()
}
192 
/**
 * Generates the private static `@NonNull parseXxx(parser)` method that reads one instance of
 * [classField]'s type from the element the parser is positioned on (callers invoke it on that
 * element's START_TAG).
 *
 * Primitive and string fields are read from that element's attributes. Class and list fields are
 * read from its direct child tags, dispatching on the tag name. The generated code throws
 * `IllegalArgumentException` for:
 * - a missing required attribute,
 * - a missing required single-valued child tag,
 * - a duplicated single-valued child tag.
 *
 * It finally calls the type's constructor with one argument per entry of [ClassFieldInfo.fields],
 * in that order.
 */
private fun generateParseClassMethod(classField: ClassFieldInfo): MethodSpec =
    MethodSpec.methodBuilder(classField.parseMethodName)
        .addAnnotation(nonNullType)
        .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
        .returns(classField.type)
        .addParameter(
            ParameterSpec.builder(xmlPullParserType, "parser").addAnnotation(nonNullType).build()
        )
        .apply {
            // Attributes: primitive and string fields. Child tags: class and list fields.
            val (attributeFields, tagFields) = classField.fields
                .partition { it is PrimitiveFieldInfo || it is StringFieldInfo }
            // Only the child-tag loop below calls parser.next() and nested parse methods, so the
            // checked exceptions are declared only when there are tag fields.
            if (tagFields.isNotEmpty()) {
                addExceptions(
                    listOf(ClassName.get(IOException::class.java), xmlPullParserExceptionType)
                )
            }
            // Reserve identifiers the generated body declares itself, so field-derived local
            // names never collide with them.
            val nameAllocator = NameAllocator().apply {
                newName("parser")
                newName("type")
                newName("depth")
                newName("innerDepth")
            }
            for (field in attributeFields) {
                // Keyed by the field so later code can look the local name up via get(field).
                val variableName = nameAllocator.newName(field.variableName, field)
                when (field) {
                    is PrimitiveFieldInfo -> {
                        val stringVariableName =
                            nameAllocator.newName("${field.variableName}String")
                        addStatement(
                            "final String \$1L =\$Wparser.getAttributeValue(null,\$W\$2S)",
                            stringVariableName, field.attributeName
                        )
                        if (field.isRequired) {
                            addControlFlow("if (\$1L == null)", stringVariableName) {
                                addStatement(
                                    "throw new IllegalArgumentException(\$1S)",
                                    "Missing attribute \"${field.attributeName}\""
                                )
                            }
                        }
                        // Primitive types use the box's parseXxx (e.g. int -> Integer.parseInt);
                        // already-boxed types use valueOf.
                        val boxedType = field.type.box()
                        val parseTypeMethodName = if (field.type.isPrimitive) {
                            "parse${field.type.toString().capitalize()}"
                        } else {
                            "valueOf"
                        }
                        if (field.isRequired) {
                            addStatement(
                                "final \$1T \$2L =\$W\$3T.\$4L($5L)", field.type, variableName,
                                boxedType, parseTypeMethodName, stringVariableName
                            )
                        } else {
                            // NOTE(review): this yields null for an absent attribute, which
                            // assumes optional primitive fields are declared with a boxed type —
                            // confirm that is validated before generation.
                            addStatement(
                                "final \$1T \$2L =\$W$3L != null ?\$W\$4T.\$5L($3L)\$W: null",
                                field.type, variableName, stringVariableName, boxedType,
                                parseTypeMethodName
                            )
                        }
                    }
                    is StringFieldInfo ->
                        addStatement(
                            "final String \$1L =\$Wparser.getAttributeValue(null,\$W\$2S)",
                            variableName, field.attributeName
                        )
                    else -> error(field)
                }
            }
            if (tagFields.isNotEmpty()) {
                // Single-valued children start null so missing/duplicate tags can be detected;
                // list children start as empty lists and are appended to.
                for (field in tagFields) {
                    val variableName = nameAllocator.newName(field.variableName, field)
                    when (field) {
                        is ClassFieldInfo ->
                            addStatement("\$1T \$2L =\$Wnull", field.type, variableName)
                        is ListFieldInfo ->
                            addStatement(
                                "final \$1T \$2L =\$Wnew \$3T<>()", field.type, variableName,
                                ArrayList::class.java
                            )
                        else -> error(field)
                    }
                }
                // Advance until END_DOCUMENT or the END_TAG closing this element.
                addStatement("int type")
                addStatement("int depth")
                addStatement("int innerDepth = parser.getDepth() + 1")
                addControlFlow(
                    "while ((type = parser.next()) != \$1T.END_DOCUMENT\$W"
                        + "&& ((depth = parser.getDepth()) >= innerDepth || type != \$1T.END_TAG))",
                    xmlPullParserType
                ) {
                    // Only START_TAGs of direct children are handled.
                    addControlFlow(
                        "if (depth > innerDepth || type != \$1T.START_TAG)", xmlPullParserType
                    ) {
                        addStatement("continue")
                    }
                    // Child tags that match no case fall through the switch and are ignored.
                    addControlFlow("switch (parser.getName())") {
                        for (field in tagFields) {
                            addControlFlow("case \$1S:", field.tagName) {
                                val variableName = nameAllocator.get(field)
                                when (field) {
                                    is ClassFieldInfo -> {
                                        addControlFlow("if (\$1L != null)", variableName) {
                                            addStatement(
                                                "throw new IllegalArgumentException(\$1S)",
                                                "Duplicate tag \"${field.tagName}\""
                                            )
                                        }
                                        addStatement(
                                            "\$1L =\$W\$2L(parser)", variableName,
                                            field.parseMethodName
                                        )
                                        addStatement("break")
                                    }
                                    is ListFieldInfo -> {
                                        // A cloned allocator keeps the element's local name
                                        // scoped to this case block.
                                        val elementNameAllocator = nameAllocator.clone()
                                        val elementVariableName = elementNameAllocator.newName(
                                            field.element.xmlName!!.toLowerCamelCase()
                                        )
                                        addStatement(
                                            "final \$1T \$2L =\$W\$3L(parser)", field.element.type,
                                            elementVariableName, field.element.parseMethodName
                                        )
                                        addStatement(
                                            "\$1L.add(\$2L)", variableName, elementVariableName
                                        )
                                        addStatement("break")
                                    }
                                    else -> error(field)
                                }
                            }
                        }
                    }
                }
            }
            // Required single-valued children must have been seen by the loop above.
            for (field in tagFields.filter { it is ClassFieldInfo && it.isRequired }) {
                addControlFlow("if ($1L == null)", nameAllocator.get(field)) {
                    addStatement(
                        "throw new IllegalArgumentException(\$1S)", "Missing tag <${field.tagName}>"
                    )
                }
            }
            // Constructor arguments follow the order of classField.fields.
            addStatement(
                classField.fields.joinToString(",\$W", "return new \$1T(", ")") {
                    nameAllocator.get(it)
                }, classField.type
            )
        }
        .build()
340 
/** Name of the generated static method that parses this class: `parse` + UpperCamelCase type name. */
private val ClassFieldInfo.parseMethodName: String
    get() {
        val typeName = type.simpleName().toUpperCamelCase()
        return "parse$typeName"
    }
343 
/** `org.xmlpull.v1.XmlSerializer`, the writer passed through the generated serialize methods. */
private val xmlSerializerType: ClassName = ClassName.get("org.xmlpull.v1", "XmlSerializer")
345 
/**
 * Generates the public `write(root)` method of the persistence class.
 *
 * The generated code:
 * - opens `mFile.startWrite()`,
 * - serializes the root object as UTF-8 XML with indented output,
 * - commits with `mFile.finishWrite(...)`.
 *
 * Any `Exception` is printed and the write is abandoned with `mFile.failWrite(...)`; it is not
 * propagated to the caller.
 */
private fun generateWriteMethod(rootField: ClassFieldInfo): MethodSpec =
    MethodSpec.methodBuilder("write")
        .apply {
            // Reserve the locals declared in the generated body so the root parameter's name
            // cannot collide with them.
            val nameAllocator = NameAllocator().apply {
                newName("outputStream")
                newName("serializer")
            }
            val parameterName = nameAllocator.newName(rootField.variableName)
            addJavadoc(
                """
                    Write${'$'}W{@link $1T}${'$'}Wto${'$'}Wthe${'$'}WXML${'$'}Wfile.

                    @param $2L the${'$'}W{@link ${'$'}1T}${'$'}Wto${'$'}Wpersist
                """.trimIndent(), rootField.type, parameterName
            )
            // The generated method returns void, so no return-nullability annotation applies.
            addModifiers(Modifier.PUBLIC)
            addParameter(
                ParameterSpec.builder(rootField.type, parameterName)
                    .addAnnotation(nonNullType)
                    .build()
            )
            // Declared outside the try so the catch block can hand it to failWrite().
            addStatement("\$1T outputStream = null", FileOutputStream::class.java)
            addControlFlow("try") {
                addStatement("outputStream = mFile.startWrite()")
                addStatement(
                    "final \$1T serializer =\$W\$2T.newSerializer()", xmlSerializerType, xmlType
                )
                addStatement(
                    "serializer.setOutput(outputStream, \$1T.UTF_8.name())",
                    StandardCharsets::class.java
                )
                addStatement(
                    "serializer.setFeature(\$1S, true)",
                    "http://xmlpull.org/v1/doc/features.html#indent-output"
                )
                addStatement("serializer.startDocument(null, true)")
                addStatement("serialize(serializer,\$W\$1L)", parameterName)
                addStatement("serializer.endDocument()")
                addStatement("mFile.finishWrite(outputStream)")
                nextControlFlow("catch (Exception e)")
                addStatement("e.printStackTrace()")
                addStatement("mFile.failWrite(outputStream)")
            }
        }
        .build()
392 
/**
 * Generates the private static `serialize(serializer, root)` method.
 *
 * The generated code wraps the root class's serialize method in the root start/end tags.
 */
private fun generateSerializeMethod(rootField: ClassFieldInfo): MethodSpec {
    val names = NameAllocator()
    names.newName("serializer")
    val rootParameterName = names.newName(rootField.variableName)
    val serializerParameter = ParameterSpec.builder(xmlSerializerType, "serializer")
        .addAnnotation(nonNullType)
        .build()
    val rootParameter = ParameterSpec.builder(rootField.type, rootParameterName)
        .addAnnotation(nonNullType)
        .build()
    return MethodSpec.methodBuilder("serialize")
        .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
        .addParameter(serializerParameter)
        .addParameter(rootParameter)
        .addException(IOException::class.java)
        .addStatement("serializer.startTag(null, \$1S)", rootField.tagName)
        .addStatement("\$1L(serializer, \$2L)", rootField.serializeMethodName, rootParameterName)
        .addStatement("serializer.endTag(null, \$1S)", rootField.tagName)
        .build()
}
415 
/**
 * Generates the private static `serializeXxx(serializer, value)` method that writes the contents of
 * one [ClassFieldInfo] type into the tag the caller has already opened.
 *
 * Primitive and string fields become attributes of that tag; class and list fields become child
 * tags. Null handling in the generated code:
 * - A required reference-typed field that is null makes it throw `IllegalArgumentException`.
 * - An optional field that is null is skipped, so no attribute or child tag is written for it.
 */
private fun generateSerializeClassMethod(classField: ClassFieldInfo): MethodSpec =
    MethodSpec.methodBuilder(classField.serializeMethodName)
        .addModifiers(Modifier.PRIVATE, Modifier.STATIC)
        .addParameter(
            ParameterSpec.builder(xmlSerializerType, "serializer")
                .addAnnotation(nonNullType)
                .build()
        )
        .apply {
            // Reserve names the generated body declares itself ("i" is the list loop index).
            val nameAllocator = NameAllocator().apply {
                newName("serializer")
                newName("i")
            }
            val parameterName = nameAllocator.newName(classField.serializeParameterName)
            addParameter(
                ParameterSpec.builder(classField.type, parameterName)
                    .addAnnotation(nonNullType)
                    .build()
            )
            addException(IOException::class.java)
            val (attributeFields, tagFields) = classField.fields
                .partition { it is PrimitiveFieldInfo || it is StringFieldInfo }
            for (field in attributeFields) {
                val variableName = "$parameterName.${field.name}"
                // Optional attributes are written only when present.
                if (!field.isRequired) {
                    beginControlFlow("if (\$1L != null)", variableName)
                }
                when (field) {
                    is PrimitiveFieldInfo -> {
                        // Java primitives cannot be null; only required boxed values need a check.
                        if (field.isRequired && !field.type.isPrimitive) {
                            addControlFlow("if (\$1L == null)", variableName) {
                                addStatement(
                                    "throw new IllegalArgumentException(\$1S)",
                                    "Field \"${field.name}\" is null"
                                )
                            }
                        }
                        val stringVariableName =
                            nameAllocator.newName("${field.variableName}String")
                        addStatement(
                            "final String \$1L =\$WString.valueOf(\$2L)", stringVariableName,
                            variableName
                        )
                        addStatement(
                            "serializer.attribute(null, \$1S, \$2L)", field.attributeName,
                            stringVariableName
                        )
                    }
                    is StringFieldInfo -> {
                        if (field.isRequired) {
                            addControlFlow("if (\$1L == null)", variableName) {
                                addStatement(
                                    "throw new IllegalArgumentException(\$1S)",
                                    "Field \"${field.name}\" is null"
                                )
                            }
                        }
                        addStatement(
                            "serializer.attribute(null, \$1S, \$2L)", field.attributeName,
                            variableName
                        )
                    }
                    else -> error(field)
                }
                if (!field.isRequired) {
                    endControlFlow()
                }
            }
            for (field in tagFields) {
                val variableName = "$parameterName.${field.name}"
                if (field.isRequired) {
                    addControlFlow("if (\$1L == null)", variableName) {
                        addStatement(
                            "throw new IllegalArgumentException(\$1S)",
                            "Field \"${field.name}\" is null"
                        )
                    }
                } else {
                    // Omit the child tag entirely when an optional field is null, instead of
                    // passing null into a @NonNull serialize method that reads its fields. The
                    // generated parser leaves such a field null when its tag is absent, so this
                    // round-trips.
                    beginControlFlow("if (\$1L != null)", variableName)
                }
                when (field) {
                    is ClassFieldInfo -> {
                        addStatement("serializer.startTag(null, \$1S)", field.tagName)
                        addStatement(
                            "\$1L(serializer, \$2L)", field.serializeMethodName, variableName
                        )
                        addStatement("serializer.endTag(null, \$1S)", field.tagName)
                    }
                    is ListFieldInfo -> {
                        val sizeVariableName = nameAllocator.newName("${field.variableName}Size")
                        addStatement(
                            "final int \$1L =\$W\$2L.size()", sizeVariableName, variableName
                        )
                        addControlFlow("for (int i = 0;\$Wi < \$1L;\$Wi++)", sizeVariableName) {
                            // Cloned allocator keeps the element's local name scoped to the loop.
                            val elementNameAllocator = nameAllocator.clone()
                            val elementVariableName = elementNameAllocator.newName(
                                field.element.xmlName!!.toLowerCamelCase()
                            )
                            addStatement(
                                "final \$1T \$2L =\$W\$3L.get(i)", field.element.type,
                                elementVariableName, variableName
                            )
                            addControlFlow("if (\$1L == null)", elementVariableName) {
                                addStatement(
                                    "throw new IllegalArgumentException(\$1S\$W+ i\$W+ \$2S)",
                                    "Field element \"${field.name}[", "]\" is null"
                                )
                            }
                            addStatement("serializer.startTag(null, \$1S)", field.element.tagName)
                            addStatement(
                                "\$1L(serializer,\$W\$2L)", field.element.serializeMethodName,
                                elementVariableName
                            )
                            addStatement("serializer.endTag(null, \$1S)", field.element.tagName)
                        }
                    }
                    else -> error(field)
                }
                if (!field.isRequired) {
                    endControlFlow()
                }
            }
        }
        .build()
535 
/** Name of the generated static method that writes this class: `serialize` + UpperCamelCase type name. */
private val ClassFieldInfo.serializeMethodName: String
    get() {
        val typeName = type.simpleName().toUpperCamelCase()
        return "serialize$typeName"
    }
538 
/** Preferred parameter name for an instance of this class: its simple type name in lowerCamelCase. */
private val ClassFieldInfo.serializeParameterName: String
    get() {
        val simpleName = type.simpleName()
        return simpleName.toLowerCamelCase()
    }
541 
/** Preferred Java local-variable name for this field: its name in lowerCamelCase. */
private val FieldInfo.variableName: String
    get() {
        return name.toLowerCamelCase()
    }
544 
/**
 * XML attribute name of a primitive or string field: its XML name (falling back to the field name)
 * converted with `toLowerCamelCase()`.
 *
 * @throws IllegalStateException if this field is not one that is stored as an XML attribute.
 */
private val FieldInfo.attributeName: String
    get() {
        check(this is PrimitiveFieldInfo || this is StringFieldInfo) {
            "Field \"$name\" (${javaClass.simpleName}) is not stored as an XML attribute"
        }
        return xmlNameOrName.toLowerCamelCase()
    }
550 
/**
 * XML tag name of a class or list field: its XML name (falling back to the field name) converted
 * with `toLowerKebabCase()`.
 *
 * @throws IllegalStateException if this field is not one that is stored as an XML tag.
 */
private val FieldInfo.tagName: String
    get() {
        check(this is ClassFieldInfo || this is ListFieldInfo) {
            "Field \"$name\" (${javaClass.simpleName}) is not stored as an XML tag"
        }
        return xmlNameOrName.toLowerKebabCase()
    }
556 
/** The explicit XML name of this field when one is set, otherwise the field's own name. */
private val FieldInfo.xmlNameOrName: String
    get() {
        val explicitName = xmlName
        return if (explicitName != null) explicitName else name
    }
559 
/** Generates the public `delete()` method, which delegates to `mFile.delete()`. */
private fun generateDeleteMethod(): MethodSpec {
    val delete = MethodSpec.methodBuilder("delete")
    delete.addJavadoc("Delete the XML file, if any.")
    delete.addModifiers(Modifier.PUBLIC)
    delete.addStatement("mFile.delete()")
    return delete.build()
}
566 
/**
 * Emits a JavaPoet control-flow section.
 *
 * [controlFlow] (formatted with [args]) opens the section, the statements added by [block] form its
 * body, and the section is closed once [block] returns.
 *
 * @return this builder, for chaining.
 */
private inline fun MethodSpec.Builder.addControlFlow(
    controlFlow: String,
    vararg args: Any,
    block: MethodSpec.Builder.() -> Unit
): MethodSpec.Builder = apply {
    beginControlFlow(controlFlow, *args)
    block()
    endControlFlow()
}
577