1#!/usr/bin/env python
2#
3# Copyright (C) 2013 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Unit testing checker.py."""
19
20# Disable check for function names to avoid errors based on old code
21# pylint: disable-msg=invalid-name
22
23from __future__ import absolute_import
24
25import array
26import collections
27import hashlib
28import io
29import itertools
30import os
31import unittest
32
33from six.moves import zip
34
35import mock  # pylint: disable=import-error
36
37from update_payload import checker
38from update_payload import common
39from update_payload import test_utils
40from update_payload import update_metadata_pb2
41from update_payload.error import PayloadError
42from update_payload.payload import Payload  # Avoid name conflicts later.
43
44
def _OpTypeByName(op_name):
  """Maps an operation name such as 'REPLACE' to its common.OpType value.

  Args:
    op_name: Name of the operation type.

  Returns:
    The matching common.OpType constant.

  Raises:
    KeyError: If op_name is not one of the supported names.
  """
  supported_names = ('REPLACE', 'REPLACE_BZ', 'SOURCE_COPY', 'SOURCE_BSDIFF',
                     'ZERO', 'DISCARD', 'REPLACE_XZ', 'PUFFDIFF',
                     'BROTLI_BSDIFF')
  name_to_type = {name: getattr(common.OpType, name)
                  for name in supported_names}
  return name_to_type[op_name]
59
60
def _GetPayloadChecker(payload_gen_write_to_file_func, payload_gen_dargs=None,
                       checker_init_dargs=None):
  """Builds a PayloadChecker over a payload generated into memory.

  Args:
    payload_gen_write_to_file_func: Callable that writes a payload into the
      file object passed as its first argument.
    payload_gen_dargs: Optional keyword arguments forwarded to the writer.
    checker_init_dargs: Optional keyword arguments forwarded to the
      checker.PayloadChecker constructor.

  Returns:
    A checker.PayloadChecker wrapping the initialized payload.
  """
  gen_kwargs = {} if payload_gen_dargs is None else payload_gen_dargs
  init_kwargs = {} if checker_init_dargs is None else checker_init_dargs

  buf = io.BytesIO()
  payload_gen_write_to_file_func(buf, **gen_kwargs)
  buf.seek(0)  # Rewind so the payload is parsed from the start.
  loaded_payload = Payload(buf)
  loaded_payload.Init()
  return checker.PayloadChecker(loaded_payload, **init_kwargs)
75
76
def _GetPayloadCheckerWithData(payload_gen):
  """Returns a payload checker for the payload written by payload_gen.

  Convenience wrapper around _GetPayloadChecker() for generator objects that
  expose a WriteToFile(file_obj) method and need no extra generator or
  checker arguments.

  Args:
    payload_gen: A payload generator object with a WriteToFile() method.

  Returns:
    A checker.PayloadChecker over the generated payload.
  """
  return _GetPayloadChecker(payload_gen.WriteToFile)
85
86
87# This class doesn't need an __init__().
88# pylint: disable=W0232
89# Unit testing is all about running protected methods.
90# pylint: disable=W0212
91# Don't bark about missing members of classes you cannot import.
92# pylint: disable=E1101
93class PayloadCheckerTest(unittest.TestCase):
94  """Tests the PayloadChecker class.
95
96  In addition to ordinary testFoo() methods, which are automatically invoked by
97  the unittest framework, in this class we make use of DoBarTest() calls that
98  implement parametric tests of certain features. In order to invoke each test,
99  which embodies a unique combination of parameter values, as a complete unit
100  test, we perform explicit enumeration of the parameter space and create
101  individual invocation contexts for each, which are then bound as
102  testBar__param1=val1__param2=val2(). The enumeration of parameter spaces for
103  all such tests is done in AddAllParametricTests().
104  """
105
106  def setUp(self):
107    """setUp function for unittest testcase"""
108    self.mock_checks = []
109
110  def tearDown(self):
111    """tearDown function for unittest testcase"""
112    # Verify that all mock functions were called.
113    for check in self.mock_checks:
114      check.mock_fn.assert_called_once_with(*check.exp_args, **check.exp_kwargs)
115
116  class MockChecksAtTearDown(object):
117    """Mock data storage.
118
119    This class stores the mock functions and its arguments to be checked at a
120    later point.
121    """
122    def __init__(self, mock_fn, *args, **kwargs):
123      self.mock_fn = mock_fn
124      self.exp_args = args
125      self.exp_kwargs = kwargs
126
127  def addPostCheckForMockFunction(self, mock_fn, *args, **kwargs):
128    """Store a mock function and its arguments to self.mock_checks
129
130    Args:
131      mock_fn: mock function object
132      args: expected positional arguments for the mock_fn
133      kwargs: expected named arguments for the mock_fn
134    """
135    self.mock_checks.append(self.MockChecksAtTearDown(mock_fn, *args, **kwargs))
136
137  def MockPayload(self):
138    """Create a mock payload object, complete with a mock manifest."""
139    payload = mock.create_autospec(Payload)
140    payload.is_init = True
141    payload.manifest = mock.create_autospec(
142        update_metadata_pb2.DeltaArchiveManifest)
143    return payload
144
145  @staticmethod
146  def NewExtent(start_block, num_blocks):
147    """Returns an Extent message.
148
149    Each of the provided fields is set iff it is >= 0; otherwise, it's left at
150    its default state.
151
152    Args:
153      start_block: The starting block of the extent.
154      num_blocks: The number of blocks in the extent.
155
156    Returns:
157      An Extent message.
158    """
159    ex = update_metadata_pb2.Extent()
160    if start_block >= 0:
161      ex.start_block = start_block
162    if num_blocks >= 0:
163      ex.num_blocks = num_blocks
164    return ex
165
166  @staticmethod
167  def NewExtentList(*args):
168    """Returns an list of extents.
169
170    Args:
171      *args: (start_block, num_blocks) pairs defining the extents.
172
173    Returns:
174      A list of Extent objects.
175    """
176    ex_list = []
177    for start_block, num_blocks in args:
178      ex_list.append(PayloadCheckerTest.NewExtent(start_block, num_blocks))
179    return ex_list
180
181  @staticmethod
182  def AddToMessage(repeated_field, field_vals):
183    for field_val in field_vals:
184      new_field = repeated_field.add()
185      new_field.CopyFrom(field_val)
186
187  def SetupAddElemTest(self, is_present, is_submsg, convert=str,
188                       linebreak=False, indent=0):
189    """Setup for testing of _CheckElem() and its derivatives.
190
191    Args:
192      is_present: Whether or not the element is found in the message.
193      is_submsg: Whether the element is a sub-message itself.
194      convert: A representation conversion function.
195      linebreak: Whether or not a linebreak is to be used in the report.
196      indent: Indentation used for the report.
197
198    Returns:
199      msg: A mock message object.
200      report: A mock report object.
201      subreport: A mock sub-report object.
202      name: An element name to check.
203      val: Expected element value.
204    """
205    name = 'foo'
206    val = 'fake submsg' if is_submsg else 'fake field'
207    subreport = 'fake subreport'
208
209    # Create a mock message.
210    msg = mock.create_autospec(update_metadata_pb2._message.Message)
211    self.addPostCheckForMockFunction(msg.HasField, name)
212    msg.HasField.return_value = is_present
213    setattr(msg, name, val)
214    # Create a mock report.
215    report = mock.create_autospec(checker._PayloadReport)
216    if is_present:
217      if is_submsg:
218        self.addPostCheckForMockFunction(report.AddSubReport, name)
219        report.AddSubReport.return_value = subreport
220      else:
221        self.addPostCheckForMockFunction(report.AddField, name, convert(val),
222                                         linebreak=linebreak, indent=indent)
223
224    return (msg, report, subreport, name, val)
225
226  def DoAddElemTest(self, is_present, is_mandatory, is_submsg, convert,
227                    linebreak, indent):
228    """Parametric testing of _CheckElem().
229
230    Args:
231      is_present: Whether or not the element is found in the message.
232      is_mandatory: Whether or not it's a mandatory element.
233      is_submsg: Whether the element is a sub-message itself.
234      convert: A representation conversion function.
235      linebreak: Whether or not a linebreak is to be used in the report.
236      indent: Indentation used for the report.
237    """
238    msg, report, subreport, name, val = self.SetupAddElemTest(
239        is_present, is_submsg, convert, linebreak, indent)
240
241    args = (msg, name, report, is_mandatory, is_submsg)
242    kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
243    if is_mandatory and not is_present:
244      self.assertRaises(PayloadError,
245                        checker.PayloadChecker._CheckElem, *args, **kwargs)
246    else:
247      ret_val, ret_subreport = checker.PayloadChecker._CheckElem(*args,
248                                                                 **kwargs)
249      self.assertEqual(val if is_present else None, ret_val)
250      self.assertEqual(subreport if is_present and is_submsg else None,
251                       ret_subreport)
252
253  def DoAddFieldTest(self, is_mandatory, is_present, convert, linebreak,
254                     indent):
255    """Parametric testing of _Check{Mandatory,Optional}Field().
256
257    Args:
258      is_mandatory: Whether we're testing a mandatory call.
259      is_present: Whether or not the element is found in the message.
260      convert: A representation conversion function.
261      linebreak: Whether or not a linebreak is to be used in the report.
262      indent: Indentation used for the report.
263    """
264    msg, report, _, name, val = self.SetupAddElemTest(
265        is_present, False, convert, linebreak, indent)
266
267    # Prepare for invocation of the tested method.
268    args = [msg, name, report]
269    kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
270    if is_mandatory:
271      args.append('bar')
272      tested_func = checker.PayloadChecker._CheckMandatoryField
273    else:
274      tested_func = checker.PayloadChecker._CheckOptionalField
275
276    # Test the method call.
277    if is_mandatory and not is_present:
278      self.assertRaises(PayloadError, tested_func, *args, **kwargs)
279    else:
280      ret_val = tested_func(*args, **kwargs)
281      self.assertEqual(val if is_present else None, ret_val)
282
283  def DoAddSubMsgTest(self, is_mandatory, is_present):
284    """Parametrized testing of _Check{Mandatory,Optional}SubMsg().
285
286    Args:
287      is_mandatory: Whether we're testing a mandatory call.
288      is_present: Whether or not the element is found in the message.
289    """
290    msg, report, subreport, name, val = self.SetupAddElemTest(is_present, True)
291
292    # Prepare for invocation of the tested method.
293    args = [msg, name, report]
294    if is_mandatory:
295      args.append('bar')
296      tested_func = checker.PayloadChecker._CheckMandatorySubMsg
297    else:
298      tested_func = checker.PayloadChecker._CheckOptionalSubMsg
299
300    # Test the method call.
301    if is_mandatory and not is_present:
302      self.assertRaises(PayloadError, tested_func, *args)
303    else:
304      ret_val, ret_subreport = tested_func(*args)
305      self.assertEqual(val if is_present else None, ret_val)
306      self.assertEqual(subreport if is_present else None, ret_subreport)
307
308  def testCheckPresentIff(self):
309    """Tests _CheckPresentIff()."""
310    self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
311        None, None, 'foo', 'bar', 'baz'))
312    self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
313        'a', 'b', 'foo', 'bar', 'baz'))
314    self.assertRaises(PayloadError, checker.PayloadChecker._CheckPresentIff,
315                      'a', None, 'foo', 'bar', 'baz')
316    self.assertRaises(PayloadError, checker.PayloadChecker._CheckPresentIff,
317                      None, 'b', 'foo', 'bar', 'baz')
318
319  def DoCheckSha256SignatureTest(self, expect_pass, expect_subprocess_call,
320                                 sig_data, sig_asn1_header,
321                                 returned_signed_hash, expected_signed_hash):
322    """Parametric testing of _CheckSha256SignatureTest().
323
324    Args:
325      expect_pass: Whether or not it should pass.
326      expect_subprocess_call: Whether to expect the openssl call to happen.
327      sig_data: The signature raw data.
328      sig_asn1_header: The ASN1 header.
329      returned_signed_hash: The signed hash data retuned by openssl.
330      expected_signed_hash: The signed hash data to compare against.
331    """
332    # Stub out the subprocess invocation.
333    with mock.patch.object(checker.PayloadChecker, '_Run') \
334         as mock_payload_checker:
335      if expect_subprocess_call:
336        mock_payload_checker([], send_data=sig_data)
337        mock_payload_checker.return_value = (
338            sig_asn1_header + returned_signed_hash, None)
339
340      if expect_pass:
341        self.assertIsNone(checker.PayloadChecker._CheckSha256Signature(
342            sig_data, 'foo', expected_signed_hash, 'bar'))
343      else:
344        self.assertRaises(PayloadError,
345                          checker.PayloadChecker._CheckSha256Signature,
346                          sig_data, 'foo', expected_signed_hash, 'bar')
347
348  def testCheckSha256Signature_Pass(self):
349    """Tests _CheckSha256Signature(); pass case."""
350    sig_data = 'fake-signature'.ljust(256)
351    signed_hash = hashlib.sha256(b'fake-data').digest()
352    self.DoCheckSha256SignatureTest(True, True, sig_data,
353                                    common.SIG_ASN1_HEADER, signed_hash,
354                                    signed_hash)
355
356  def testCheckSha256Signature_FailBadSignature(self):
357    """Tests _CheckSha256Signature(); fails due to malformed signature."""
358    sig_data = 'fake-signature'  # Malformed (not 256 bytes in length).
359    signed_hash = hashlib.sha256(b'fake-data').digest()
360    self.DoCheckSha256SignatureTest(False, False, sig_data,
361                                    common.SIG_ASN1_HEADER, signed_hash,
362                                    signed_hash)
363
364  def testCheckSha256Signature_FailBadOutputLength(self):
365    """Tests _CheckSha256Signature(); fails due to unexpected output length."""
366    sig_data = 'fake-signature'.ljust(256)
367    signed_hash = b'fake-hash'  # Malformed (not 32 bytes in length).
368    self.DoCheckSha256SignatureTest(False, True, sig_data,
369                                    common.SIG_ASN1_HEADER, signed_hash,
370                                    signed_hash)
371
372  def testCheckSha256Signature_FailBadAsnHeader(self):
373    """Tests _CheckSha256Signature(); fails due to bad ASN1 header."""
374    sig_data = 'fake-signature'.ljust(256)
375    signed_hash = hashlib.sha256(b'fake-data').digest()
376    bad_asn1_header = b'bad-asn-header'.ljust(len(common.SIG_ASN1_HEADER))
377    self.DoCheckSha256SignatureTest(False, True, sig_data, bad_asn1_header,
378                                    signed_hash, signed_hash)
379
380  def testCheckSha256Signature_FailBadHash(self):
381    """Tests _CheckSha256Signature(); fails due to bad hash returned."""
382    sig_data = 'fake-signature'.ljust(256)
383    expected_signed_hash = hashlib.sha256(b'fake-data').digest()
384    returned_signed_hash = hashlib.sha256(b'bad-fake-data').digest()
385    self.DoCheckSha256SignatureTest(False, True, sig_data,
386                                    common.SIG_ASN1_HEADER,
387                                    expected_signed_hash, returned_signed_hash)
388
389  def testCheckBlocksFitLength_Pass(self):
390    """Tests _CheckBlocksFitLength(); pass case."""
391    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
392        64, 4, 16, 'foo'))
393    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
394        60, 4, 16, 'foo'))
395    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
396        49, 4, 16, 'foo'))
397    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
398        48, 3, 16, 'foo'))
399
400  def testCheckBlocksFitLength_TooManyBlocks(self):
401    """Tests _CheckBlocksFitLength(); fails due to excess blocks."""
402    self.assertRaises(PayloadError,
403                      checker.PayloadChecker._CheckBlocksFitLength,
404                      64, 5, 16, 'foo')
405    self.assertRaises(PayloadError,
406                      checker.PayloadChecker._CheckBlocksFitLength,
407                      60, 5, 16, 'foo')
408    self.assertRaises(PayloadError,
409                      checker.PayloadChecker._CheckBlocksFitLength,
410                      49, 5, 16, 'foo')
411    self.assertRaises(PayloadError,
412                      checker.PayloadChecker._CheckBlocksFitLength,
413                      48, 4, 16, 'foo')
414
415  def testCheckBlocksFitLength_TooFewBlocks(self):
416    """Tests _CheckBlocksFitLength(); fails due to insufficient blocks."""
417    self.assertRaises(PayloadError,
418                      checker.PayloadChecker._CheckBlocksFitLength,
419                      64, 3, 16, 'foo')
420    self.assertRaises(PayloadError,
421                      checker.PayloadChecker._CheckBlocksFitLength,
422                      60, 3, 16, 'foo')
423    self.assertRaises(PayloadError,
424                      checker.PayloadChecker._CheckBlocksFitLength,
425                      49, 3, 16, 'foo')
426    self.assertRaises(PayloadError,
427                      checker.PayloadChecker._CheckBlocksFitLength,
428                      48, 2, 16, 'foo')
429
430  def DoCheckManifestTest(self, fail_mismatched_block_size, fail_bad_sigs,
431                          fail_mismatched_oki_ori, fail_bad_oki, fail_bad_ori,
432                          fail_bad_nki, fail_bad_nri, fail_old_kernel_fs_size,
433                          fail_old_rootfs_fs_size, fail_new_kernel_fs_size,
434                          fail_new_rootfs_fs_size):
435    """Parametric testing of _CheckManifest().
436
437    Args:
438      fail_mismatched_block_size: Simulate a missing block_size field.
439      fail_bad_sigs: Make signatures descriptor inconsistent.
440      fail_mismatched_oki_ori: Make old rootfs/kernel info partially present.
441      fail_bad_oki: Tamper with old kernel info.
442      fail_bad_ori: Tamper with old rootfs info.
443      fail_bad_nki: Tamper with new kernel info.
444      fail_bad_nri: Tamper with new rootfs info.
445      fail_old_kernel_fs_size: Make old kernel fs size too big.
446      fail_old_rootfs_fs_size: Make old rootfs fs size too big.
447      fail_new_kernel_fs_size: Make new kernel fs size too big.
448      fail_new_rootfs_fs_size: Make new rootfs fs size too big.
449    """
450    # Generate a test payload. For this test, we only care about the manifest
451    # and don't need any data blobs, hence we can use a plain paylaod generator
452    # (which also gives us more control on things that can be screwed up).
453    payload_gen = test_utils.PayloadGenerator()
454
455    # Tamper with block size, if required.
456    if fail_mismatched_block_size:
457      payload_gen.SetBlockSize(test_utils.KiB(1))
458    else:
459      payload_gen.SetBlockSize(test_utils.KiB(4))
460
461    # Add some operations.
462    payload_gen.AddOperation(common.ROOTFS, common.OpType.SOURCE_COPY,
463                             src_extents=[(0, 16), (16, 497)],
464                             dst_extents=[(16, 496), (0, 16)])
465    payload_gen.AddOperation(common.KERNEL, common.OpType.SOURCE_COPY,
466                             src_extents=[(0, 8), (8, 8)],
467                             dst_extents=[(8, 8), (0, 8)])
468
469    # Set an invalid signatures block (offset but no size), if required.
470    if fail_bad_sigs:
471      payload_gen.SetSignatures(32, None)
472
473    # Set partition / filesystem sizes.
474    rootfs_part_size = test_utils.MiB(8)
475    kernel_part_size = test_utils.KiB(512)
476    old_rootfs_fs_size = new_rootfs_fs_size = rootfs_part_size
477    old_kernel_fs_size = new_kernel_fs_size = kernel_part_size
478    if fail_old_kernel_fs_size:
479      old_kernel_fs_size += 100
480    if fail_old_rootfs_fs_size:
481      old_rootfs_fs_size += 100
482    if fail_new_kernel_fs_size:
483      new_kernel_fs_size += 100
484    if fail_new_rootfs_fs_size:
485      new_rootfs_fs_size += 100
486
487    # Add old kernel/rootfs partition info, as required.
488    if fail_mismatched_oki_ori or fail_old_kernel_fs_size or fail_bad_oki:
489      oki_hash = (None if fail_bad_oki
490                  else hashlib.sha256(b'fake-oki-content').digest())
491      payload_gen.SetPartInfo(common.KERNEL, False, old_kernel_fs_size,
492                              oki_hash)
493    if not fail_mismatched_oki_ori and (fail_old_rootfs_fs_size or
494                                        fail_bad_ori):
495      ori_hash = (None if fail_bad_ori
496                  else hashlib.sha256(b'fake-ori-content').digest())
497      payload_gen.SetPartInfo(common.ROOTFS, False, old_rootfs_fs_size,
498                              ori_hash)
499
500    # Add new kernel/rootfs partition info.
501    payload_gen.SetPartInfo(
502        common.KERNEL, True, new_kernel_fs_size,
503        None if fail_bad_nki else hashlib.sha256(b'fake-nki-content').digest())
504    payload_gen.SetPartInfo(
505        common.ROOTFS, True, new_rootfs_fs_size,
506        None if fail_bad_nri else hashlib.sha256(b'fake-nri-content').digest())
507
508    # Set the minor version.
509    payload_gen.SetMinorVersion(0)
510
511    # Create the test object.
512    payload_checker = _GetPayloadChecker(payload_gen.WriteToFile)
513    report = checker._PayloadReport()
514
515    should_fail = (fail_mismatched_block_size or fail_bad_sigs or
516                   fail_mismatched_oki_ori or fail_bad_oki or fail_bad_ori or
517                   fail_bad_nki or fail_bad_nri or fail_old_kernel_fs_size or
518                   fail_old_rootfs_fs_size or fail_new_kernel_fs_size or
519                   fail_new_rootfs_fs_size)
520    part_sizes = {
521        common.ROOTFS: rootfs_part_size,
522        common.KERNEL: kernel_part_size
523    }
524
525    if should_fail:
526      self.assertRaises(PayloadError, payload_checker._CheckManifest, report,
527                        part_sizes)
528    else:
529      self.assertIsNone(payload_checker._CheckManifest(report, part_sizes))
530
531  def testCheckLength(self):
532    """Tests _CheckLength()."""
533    payload_checker = checker.PayloadChecker(self.MockPayload())
534    block_size = payload_checker.block_size
535
536    # Passes.
537    self.assertIsNone(payload_checker._CheckLength(
538        int(3.5 * block_size), 4, 'foo', 'bar'))
539    # Fails, too few blocks.
540    self.assertRaises(PayloadError, payload_checker._CheckLength,
541                      int(3.5 * block_size), 3, 'foo', 'bar')
542    # Fails, too many blocks.
543    self.assertRaises(PayloadError, payload_checker._CheckLength,
544                      int(3.5 * block_size), 5, 'foo', 'bar')
545
546  def testCheckExtents(self):
547    """Tests _CheckExtents()."""
548    payload_checker = checker.PayloadChecker(self.MockPayload())
549    block_size = payload_checker.block_size
550
551    # Passes w/ all real extents.
552    extents = self.NewExtentList((0, 4), (8, 3), (1024, 16))
553    self.assertEqual(
554        23,
555        payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
556                                      collections.defaultdict(int), 'foo'))
557
558    # Fails, extent missing a start block.
559    extents = self.NewExtentList((-1, 4), (8, 3), (1024, 16))
560    self.assertRaises(
561        PayloadError, payload_checker._CheckExtents, extents,
562        (1024 + 16) * block_size, collections.defaultdict(int), 'foo')
563
564    # Fails, extent missing block count.
565    extents = self.NewExtentList((0, -1), (8, 3), (1024, 16))
566    self.assertRaises(
567        PayloadError, payload_checker._CheckExtents, extents,
568        (1024 + 16) * block_size, collections.defaultdict(int), 'foo')
569
570    # Fails, extent has zero blocks.
571    extents = self.NewExtentList((0, 4), (8, 3), (1024, 0))
572    self.assertRaises(
573        PayloadError, payload_checker._CheckExtents, extents,
574        (1024 + 16) * block_size, collections.defaultdict(int), 'foo')
575
576    # Fails, extent exceeds partition boundaries.
577    extents = self.NewExtentList((0, 4), (8, 3), (1024, 16))
578    self.assertRaises(
579        PayloadError, payload_checker._CheckExtents, extents,
580        (1024 + 15) * block_size, collections.defaultdict(int), 'foo')
581
582  def testCheckReplaceOperation(self):
583    """Tests _CheckReplaceOperation() where op.type == REPLACE."""
584    payload_checker = checker.PayloadChecker(self.MockPayload())
585    block_size = payload_checker.block_size
586    data_length = 10000
587
588    op = mock.create_autospec(update_metadata_pb2.InstallOperation)
589    op.type = common.OpType.REPLACE
590
591    # Pass.
592    op.src_extents = []
593    self.assertIsNone(
594        payload_checker._CheckReplaceOperation(
595            op, data_length, (data_length + block_size - 1) // block_size,
596            'foo'))
597
598    # Fail, src extents founds.
599    op.src_extents = ['bar']
600    self.assertRaises(
601        PayloadError, payload_checker._CheckReplaceOperation,
602        op, data_length, (data_length + block_size - 1) // block_size, 'foo')
603
604    # Fail, missing data.
605    op.src_extents = []
606    self.assertRaises(
607        PayloadError, payload_checker._CheckReplaceOperation,
608        op, None, (data_length + block_size - 1) // block_size, 'foo')
609
610    # Fail, length / block number mismatch.
611    op.src_extents = ['bar']
612    self.assertRaises(
613        PayloadError, payload_checker._CheckReplaceOperation,
614        op, data_length, (data_length + block_size - 1) // block_size + 1,
615        'foo')
616
617  def testCheckReplaceBzOperation(self):
618    """Tests _CheckReplaceOperation() where op.type == REPLACE_BZ."""
619    payload_checker = checker.PayloadChecker(self.MockPayload())
620    block_size = payload_checker.block_size
621    data_length = block_size * 3
622
623    op = mock.create_autospec(
624        update_metadata_pb2.InstallOperation)
625    op.type = common.OpType.REPLACE_BZ
626
627    # Pass.
628    op.src_extents = []
629    self.assertIsNone(
630        payload_checker._CheckReplaceOperation(
631            op, data_length, (data_length + block_size - 1) // block_size + 5,
632            'foo'))
633
634    # Fail, src extents founds.
635    op.src_extents = ['bar']
636    self.assertRaises(
637        PayloadError, payload_checker._CheckReplaceOperation,
638        op, data_length, (data_length + block_size - 1) // block_size + 5,
639        'foo')
640
641    # Fail, missing data.
642    op.src_extents = []
643    self.assertRaises(
644        PayloadError, payload_checker._CheckReplaceOperation,
645        op, None, (data_length + block_size - 1) // block_size, 'foo')
646
647    # Fail, too few blocks to justify BZ.
648    op.src_extents = []
649    self.assertRaises(
650        PayloadError, payload_checker._CheckReplaceOperation,
651        op, data_length, (data_length + block_size - 1) // block_size, 'foo')
652
653    # Fail, total_dst_blocks is a floating point value.
654    op.src_extents = []
655    self.assertRaises(
656        PayloadError, payload_checker._CheckReplaceOperation,
657        op, data_length, (data_length + block_size - 1) / block_size, 'foo')
658
659  def testCheckReplaceXzOperation(self):
660    """Tests _CheckReplaceOperation() where op.type == REPLACE_XZ."""
661    payload_checker = checker.PayloadChecker(self.MockPayload())
662    block_size = payload_checker.block_size
663    data_length = block_size * 3
664
665    op = mock.create_autospec(
666        update_metadata_pb2.InstallOperation)
667    op.type = common.OpType.REPLACE_XZ
668
669    # Pass.
670    op.src_extents = []
671    self.assertIsNone(
672        payload_checker._CheckReplaceOperation(
673            op, data_length, (data_length + block_size - 1) // block_size + 5,
674            'foo'))
675
676    # Fail, src extents founds.
677    op.src_extents = ['bar']
678    self.assertRaises(
679        PayloadError, payload_checker._CheckReplaceOperation,
680        op, data_length, (data_length + block_size - 1) // block_size + 5,
681        'foo')
682
683    # Fail, missing data.
684    op.src_extents = []
685    self.assertRaises(
686        PayloadError, payload_checker._CheckReplaceOperation,
687        op, None, (data_length + block_size - 1) // block_size, 'foo')
688
689    # Fail, too few blocks to justify XZ.
690    op.src_extents = []
691    self.assertRaises(
692        PayloadError, payload_checker._CheckReplaceOperation,
693        op, data_length, (data_length + block_size - 1) // block_size, 'foo')
694
695    # Fail, total_dst_blocks is a floating point value.
696    op.src_extents = []
697    self.assertRaises(
698        PayloadError, payload_checker._CheckReplaceOperation,
699        op, data_length, (data_length + block_size - 1) / block_size, 'foo')
700
701  def testCheckAnyDiff(self):
702    """Tests _CheckAnyDiffOperation()."""
703    payload_checker = checker.PayloadChecker(self.MockPayload())
704    op = update_metadata_pb2.InstallOperation()
705
706    # Pass.
707    self.assertIsNone(
708        payload_checker._CheckAnyDiffOperation(op, 10000, 3, 'foo'))
709
710    # Fail, missing data blob.
711    self.assertRaises(
712        PayloadError, payload_checker._CheckAnyDiffOperation,
713        op, None, 3, 'foo')
714
715    # Fail, too big of a diff blob (unjustified).
716    self.assertRaises(
717        PayloadError, payload_checker._CheckAnyDiffOperation,
718        op, 10000, 2, 'foo')
719
720  def testCheckSourceCopyOperation_Pass(self):
721    """Tests _CheckSourceCopyOperation(); pass case."""
722    payload_checker = checker.PayloadChecker(self.MockPayload())
723    self.assertIsNone(
724        payload_checker._CheckSourceCopyOperation(None, 134, 134, 'foo'))
725
726  def testCheckSourceCopyOperation_FailContainsData(self):
727    """Tests _CheckSourceCopyOperation(); message contains data."""
728    payload_checker = checker.PayloadChecker(self.MockPayload())
729    self.assertRaises(PayloadError, payload_checker._CheckSourceCopyOperation,
730                      134, 0, 0, 'foo')
731
732  def testCheckSourceCopyOperation_FailBlockCountsMismatch(self):
733    """Tests _CheckSourceCopyOperation(); src and dst block totals not equal."""
734    payload_checker = checker.PayloadChecker(self.MockPayload())
735    self.assertRaises(PayloadError, payload_checker._CheckSourceCopyOperation,
736                      None, 0, 1, 'foo')
737
738  def DoCheckOperationTest(self, op_type_name, allow_unhashed,
739                           fail_src_extents, fail_dst_extents,
740                           fail_mismatched_data_offset_length,
741                           fail_missing_dst_extents, fail_src_length,
742                           fail_dst_length, fail_data_hash,
743                           fail_prev_data_offset, fail_bad_minor_version):
744    """Parametric testing of _CheckOperation().
745
746    Args:
747      op_type_name: 'REPLACE', 'REPLACE_BZ', 'REPLACE_XZ',
748        'SOURCE_COPY', 'SOURCE_BSDIFF', BROTLI_BSDIFF or 'PUFFDIFF'.
749      allow_unhashed: Whether we're allowing to not hash the data.
750      fail_src_extents: Tamper with src extents.
751      fail_dst_extents: Tamper with dst extents.
752      fail_mismatched_data_offset_length: Make data_{offset,length}
753        inconsistent.
754      fail_missing_dst_extents: Do not include dst extents.
755      fail_src_length: Make src length inconsistent.
756      fail_dst_length: Make dst length inconsistent.
757      fail_data_hash: Tamper with the data blob hash.
758      fail_prev_data_offset: Make data space uses incontiguous.
759      fail_bad_minor_version: Make minor version incompatible with op.
760    """
761    op_type = _OpTypeByName(op_type_name)
762
763    # Create the test object.
764    payload = self.MockPayload()
765    payload_checker = checker.PayloadChecker(payload,
766                                             allow_unhashed=allow_unhashed)
767    block_size = payload_checker.block_size
768
769    # Create auxiliary arguments.
770    old_part_size = test_utils.MiB(4)
771    new_part_size = test_utils.MiB(8)
772    old_block_counters = array.array(
773        'B', [0] * ((old_part_size + block_size - 1) // block_size))
774    new_block_counters = array.array(
775        'B', [0] * ((new_part_size + block_size - 1) // block_size))
776    prev_data_offset = 1876
777    blob_hash_counts = collections.defaultdict(int)
778
779    # Create the operation object for the test.
780    op = update_metadata_pb2.InstallOperation()
781    op.type = op_type
782
783    total_src_blocks = 0
784    if op_type in (common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF,
785                   common.OpType.PUFFDIFF, common.OpType.BROTLI_BSDIFF):
786      if fail_src_extents:
787        self.AddToMessage(op.src_extents,
788                          self.NewExtentList((1, 0)))
789      else:
790        self.AddToMessage(op.src_extents,
791                          self.NewExtentList((1, 16)))
792        total_src_blocks = 16
793
794    payload_checker.major_version = common.BRILLO_MAJOR_PAYLOAD_VERSION
795    if op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ):
796      payload_checker.minor_version = 0
797    elif op_type in (common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF):
798      payload_checker.minor_version = 1 if fail_bad_minor_version else 2
799    if op_type == common.OpType.REPLACE_XZ:
800      payload_checker.minor_version = 2 if fail_bad_minor_version else 3
801    elif op_type in (common.OpType.ZERO, common.OpType.DISCARD,
802                     common.OpType.BROTLI_BSDIFF):
803      payload_checker.minor_version = 3 if fail_bad_minor_version else 4
804    elif op_type == common.OpType.PUFFDIFF:
805      payload_checker.minor_version = 4 if fail_bad_minor_version else 5
806
807    if op_type != common.OpType.SOURCE_COPY:
808      if not fail_mismatched_data_offset_length:
809        op.data_length = 16 * block_size - 8
810      if fail_prev_data_offset:
811        op.data_offset = prev_data_offset + 16
812      else:
813        op.data_offset = prev_data_offset
814
815      fake_data = 'fake-data'.ljust(op.data_length)
816      if not allow_unhashed and not fail_data_hash:
817        # Create a valid data blob hash.
818        op.data_sha256_hash = hashlib.sha256(fake_data.encode('utf-8')).digest()
819        payload.ReadDataBlob.return_value = fake_data.encode('utf-8')
820
821      elif fail_data_hash:
822        # Create an invalid data blob hash.
823        op.data_sha256_hash = hashlib.sha256(
824            fake_data.replace(' ', '-').encode('utf-8')).digest()
825        payload.ReadDataBlob.return_value = fake_data.encode('utf-8')
826
827    total_dst_blocks = 0
828    if not fail_missing_dst_extents:
829      total_dst_blocks = 16
830      if fail_dst_extents:
831        self.AddToMessage(op.dst_extents,
832                          self.NewExtentList((4, 16), (32, 0)))
833      else:
834        self.AddToMessage(op.dst_extents,
835                          self.NewExtentList((4, 8), (64, 8)))
836
837    if total_src_blocks:
838      if fail_src_length:
839        op.src_length = total_src_blocks * block_size + 8
840      elif (op_type == common.OpType.SOURCE_BSDIFF and
841            payload_checker.minor_version <= 3):
842        op.src_length = total_src_blocks * block_size
843    elif fail_src_length:
844      # Add an orphaned src_length.
845      op.src_length = 16
846
847    if total_dst_blocks:
848      if fail_dst_length:
849        op.dst_length = total_dst_blocks * block_size + 8
850      elif (op_type == common.OpType.SOURCE_BSDIFF and
851            payload_checker.minor_version <= 3):
852        op.dst_length = total_dst_blocks * block_size
853
854    should_fail = (fail_src_extents or fail_dst_extents or
855                   fail_mismatched_data_offset_length or
856                   fail_missing_dst_extents or fail_src_length or
857                   fail_dst_length or fail_data_hash or fail_prev_data_offset or
858                   fail_bad_minor_version)
859    args = (op, 'foo', old_block_counters, new_block_counters,
860            old_part_size, new_part_size, prev_data_offset,
861            blob_hash_counts)
862    if should_fail:
863      self.assertRaises(PayloadError, payload_checker._CheckOperation, *args)
864    else:
865      self.assertEqual(op.data_length if op.HasField('data_length') else 0,
866                       payload_checker._CheckOperation(*args))
867
868  def testAllocBlockCounters(self):
869    """Tests _CheckMoveOperation()."""
870    payload_checker = checker.PayloadChecker(self.MockPayload())
871    block_size = payload_checker.block_size
872
873    # Check allocation for block-aligned partition size, ensure it's integers.
874    result = payload_checker._AllocBlockCounters(16 * block_size)
875    self.assertEqual(16, len(result))
876    self.assertEqual(int, type(result[0]))
877
878    # Check allocation of unaligned partition sizes.
879    result = payload_checker._AllocBlockCounters(16 * block_size - 1)
880    self.assertEqual(16, len(result))
881    result = payload_checker._AllocBlockCounters(16 * block_size + 1)
882    self.assertEqual(17, len(result))
883
884  def DoCheckOperationsTest(self, fail_nonexhaustive_full_update):
885    """Tests _CheckOperations()."""
886    # Generate a test payload. For this test, we only care about one
887    # (arbitrary) set of operations, so we'll only be generating kernel and
888    # test with them.
889    payload_gen = test_utils.PayloadGenerator()
890
891    block_size = test_utils.KiB(4)
892    payload_gen.SetBlockSize(block_size)
893
894    rootfs_part_size = test_utils.MiB(8)
895
896    # Fake rootfs operations in a full update, tampered with as required.
897    rootfs_op_type = common.OpType.REPLACE
898    rootfs_data_length = rootfs_part_size
899    if fail_nonexhaustive_full_update:
900      rootfs_data_length -= block_size
901
902    payload_gen.AddOperation(common.ROOTFS, rootfs_op_type,
903                             dst_extents=
904                             [(0, rootfs_data_length // block_size)],
905                             data_offset=0,
906                             data_length=rootfs_data_length)
907
908    # Create the test object.
909    payload_checker = _GetPayloadChecker(payload_gen.WriteToFile,
910                                         checker_init_dargs={
911                                             'allow_unhashed': True})
912    payload_checker.payload_type = checker._TYPE_FULL
913    report = checker._PayloadReport()
914    partition = next((p for p in payload_checker.payload.manifest.partitions
915                      if p.partition_name == common.ROOTFS), None)
916    args = (partition.operations, report, 'foo',
917            0, rootfs_part_size, rootfs_part_size, rootfs_part_size, 0)
918    if fail_nonexhaustive_full_update:
919      self.assertRaises(PayloadError, payload_checker._CheckOperations, *args)
920    else:
921      self.assertEqual(rootfs_data_length,
922                       payload_checker._CheckOperations(*args))
923
924  def DoCheckSignaturesTest(self, fail_empty_sigs_blob, fail_sig_missing_fields,
925                            fail_unknown_sig_version, fail_incorrect_sig):
926    """Tests _CheckSignatures()."""
927    # Generate a test payload. For this test, we only care about the signature
928    # block and how it relates to the payload hash. Therefore, we're generating
929    # a random (otherwise useless) payload for this purpose.
930    payload_gen = test_utils.EnhancedPayloadGenerator()
931    block_size = test_utils.KiB(4)
932    payload_gen.SetBlockSize(block_size)
933    rootfs_part_size = test_utils.MiB(2)
934    kernel_part_size = test_utils.KiB(16)
935    payload_gen.SetPartInfo(common.ROOTFS, True, rootfs_part_size,
936                            hashlib.sha256(b'fake-new-rootfs-content').digest())
937    payload_gen.SetPartInfo(common.KERNEL, True, kernel_part_size,
938                            hashlib.sha256(b'fake-new-kernel-content').digest())
939    payload_gen.SetMinorVersion(0)
940    payload_gen.AddOperationWithData(
941        common.ROOTFS, common.OpType.REPLACE,
942        dst_extents=[(0, rootfs_part_size // block_size)],
943        data_blob=os.urandom(rootfs_part_size))
944
945    do_forge_sigs_data = (fail_empty_sigs_blob or fail_sig_missing_fields or
946                          fail_unknown_sig_version or fail_incorrect_sig)
947
948    sigs_data = None
949    if do_forge_sigs_data:
950      sigs_gen = test_utils.SignaturesGenerator()
951      if not fail_empty_sigs_blob:
952        if fail_sig_missing_fields:
953          sig_data = None
954        else:
955          sig_data = test_utils.SignSha256(b'fake-payload-content',
956                                           test_utils._PRIVKEY_FILE_NAME)
957        sigs_gen.AddSig(5 if fail_unknown_sig_version else 1, sig_data)
958
959      sigs_data = sigs_gen.ToBinary()
960      payload_gen.SetSignatures(payload_gen.curr_offset, len(sigs_data))
961
962    # Generate payload (complete w/ signature) and create the test object.
963    payload_checker = _GetPayloadChecker(
964        payload_gen.WriteToFileWithData,
965        payload_gen_dargs={
966            'sigs_data': sigs_data,
967            'privkey_file_name': test_utils._PRIVKEY_FILE_NAME})
968    payload_checker.payload_type = checker._TYPE_FULL
969    report = checker._PayloadReport()
970
971    # We have to check the manifest first in order to set signature attributes.
972    payload_checker._CheckManifest(report, {
973        common.ROOTFS: rootfs_part_size,
974        common.KERNEL: kernel_part_size
975    })
976
977    should_fail = (fail_empty_sigs_blob or fail_sig_missing_fields or
978                   fail_unknown_sig_version or fail_incorrect_sig)
979    args = (report, test_utils._PUBKEY_FILE_NAME)
980    if should_fail:
981      self.assertRaises(PayloadError, payload_checker._CheckSignatures, *args)
982    else:
983      self.assertIsNone(payload_checker._CheckSignatures(*args))
984
985  def DoCheckManifestMinorVersionTest(self, minor_version, payload_type):
986    """Parametric testing for CheckManifestMinorVersion().
987
988    Args:
989      minor_version: The payload minor version to test with.
990      payload_type: The type of the payload we're testing, delta or full.
991    """
992    # Create the test object.
993    payload = self.MockPayload()
994    payload.manifest.minor_version = minor_version
995    payload_checker = checker.PayloadChecker(payload)
996    payload_checker.payload_type = payload_type
997    report = checker._PayloadReport()
998
999    should_succeed = (
1000        (minor_version == 0 and payload_type == checker._TYPE_FULL) or
1001        (minor_version == 2 and payload_type == checker._TYPE_DELTA) or
1002        (minor_version == 3 and payload_type == checker._TYPE_DELTA) or
1003        (minor_version == 4 and payload_type == checker._TYPE_DELTA) or
1004        (minor_version == 5 and payload_type == checker._TYPE_DELTA))
1005    args = (report,)
1006
1007    if should_succeed:
1008      self.assertIsNone(payload_checker._CheckManifestMinorVersion(*args))
1009    else:
1010      self.assertRaises(PayloadError,
1011                        payload_checker._CheckManifestMinorVersion, *args)
1012
1013  def DoRunTest(self, rootfs_part_size_provided, kernel_part_size_provided,
1014                fail_wrong_payload_type, fail_invalid_block_size,
1015                fail_mismatched_metadata_size, fail_mismatched_block_size,
1016                fail_excess_data, fail_rootfs_part_size_exceeded,
1017                fail_kernel_part_size_exceeded):
1018    """Tests Run()."""
1019    # Generate a test payload. For this test, we generate a full update that
1020    # has sample kernel and rootfs operations. Since most testing is done with
1021    # internal PayloadChecker methods that are tested elsewhere, here we only
1022    # tamper with what's actually being manipulated and/or tested in the Run()
1023    # method itself. Note that the checker doesn't verify partition hashes, so
1024    # they're safe to fake.
1025    payload_gen = test_utils.EnhancedPayloadGenerator()
1026    block_size = test_utils.KiB(4)
1027    payload_gen.SetBlockSize(block_size)
1028    kernel_filesystem_size = test_utils.KiB(16)
1029    rootfs_filesystem_size = test_utils.MiB(2)
1030    payload_gen.SetPartInfo(common.ROOTFS, True, rootfs_filesystem_size,
1031                            hashlib.sha256(b'fake-new-rootfs-content').digest())
1032    payload_gen.SetPartInfo(common.KERNEL, True, kernel_filesystem_size,
1033                            hashlib.sha256(b'fake-new-kernel-content').digest())
1034    payload_gen.SetMinorVersion(0)
1035
1036    rootfs_part_size = 0
1037    if rootfs_part_size_provided:
1038      rootfs_part_size = rootfs_filesystem_size + block_size
1039    rootfs_op_size = rootfs_part_size or rootfs_filesystem_size
1040    if fail_rootfs_part_size_exceeded:
1041      rootfs_op_size += block_size
1042    payload_gen.AddOperationWithData(
1043        common.ROOTFS, common.OpType.REPLACE,
1044        dst_extents=[(0, rootfs_op_size // block_size)],
1045        data_blob=os.urandom(rootfs_op_size))
1046
1047    kernel_part_size = 0
1048    if kernel_part_size_provided:
1049      kernel_part_size = kernel_filesystem_size + block_size
1050    kernel_op_size = kernel_part_size or kernel_filesystem_size
1051    if fail_kernel_part_size_exceeded:
1052      kernel_op_size += block_size
1053    payload_gen.AddOperationWithData(
1054        common.KERNEL, common.OpType.REPLACE,
1055        dst_extents=[(0, kernel_op_size // block_size)],
1056        data_blob=os.urandom(kernel_op_size))
1057
1058    # Generate payload (complete w/ signature) and create the test object.
1059    if fail_invalid_block_size:
1060      use_block_size = block_size + 5  # Not a power of two.
1061    elif fail_mismatched_block_size:
1062      use_block_size = block_size * 2  # Different that payload stated.
1063    else:
1064      use_block_size = block_size
1065
1066    # For the unittests 237 is the value that generated for the payload.
1067    metadata_size = 237
1068    if fail_mismatched_metadata_size:
1069      metadata_size += 1
1070
1071    kwargs = {
1072        'payload_gen_dargs': {
1073            'privkey_file_name': test_utils._PRIVKEY_FILE_NAME,
1074            'padding': os.urandom(1024) if fail_excess_data else None},
1075        'checker_init_dargs': {
1076            'assert_type': 'delta' if fail_wrong_payload_type else 'full',
1077            'block_size': use_block_size}}
1078    if fail_invalid_block_size:
1079      self.assertRaises(PayloadError, _GetPayloadChecker,
1080                        payload_gen.WriteToFileWithData, **kwargs)
1081    else:
1082      payload_checker = _GetPayloadChecker(payload_gen.WriteToFileWithData,
1083                                           **kwargs)
1084
1085      kwargs2 = {
1086          'pubkey_file_name': test_utils._PUBKEY_FILE_NAME,
1087          'metadata_size': metadata_size,
1088          'part_sizes': {
1089              common.KERNEL: kernel_part_size,
1090              common.ROOTFS: rootfs_part_size}}
1091
1092      should_fail = (fail_wrong_payload_type or fail_mismatched_block_size or
1093                     fail_mismatched_metadata_size or fail_excess_data or
1094                     fail_rootfs_part_size_exceeded or
1095                     fail_kernel_part_size_exceeded)
1096      if should_fail:
1097        self.assertRaises(PayloadError, payload_checker.Run, **kwargs2)
1098      else:
1099        self.assertIsNone(payload_checker.Run(**kwargs2))
1100
1101
1102# This implements a generic API, hence the occasional unused args.
1103# pylint: disable=W0613
def ValidateCheckOperationTest(op_type_name, allow_unhashed,
                               fail_src_extents, fail_dst_extents,
                               fail_mismatched_data_offset_length,
                               fail_missing_dst_extents, fail_src_length,
                               fail_dst_length, fail_data_hash,
                               fail_prev_data_offset, fail_bad_minor_version):
  """Returns True iff the combination of arguments represents a valid test.

  A combination is discarded when it tampers with something the operation
  type does not have: source data / minor version for REPLACE* operations,
  or a data blob for SOURCE_COPY.
  """
  op_type = _OpTypeByName(op_type_name)
  replace_op_types = (common.OpType.REPLACE, common.OpType.REPLACE_BZ,
                      common.OpType.REPLACE_XZ)

  if op_type in replace_op_types:
    # REPLACE* operations don't read data from the src partition and are
    # compatible with all valid minor versions, so there is nothing to break.
    return not (fail_src_extents or fail_src_length or fail_bad_minor_version)

  if op_type == common.OpType.SOURCE_COPY:
    # SOURCE_COPY carries no data, so data-related tampering is meaningless.
    return not (fail_mismatched_data_offset_length or fail_data_hash or
                fail_prev_data_offset)

  return True
1129
1130
def TestMethodBody(run_method_name, run_dargs):
  """Builds a test method that forwards to another method of the instance.

  Args:
    run_method_name: Name of the method to look up on the instance at call time.
    run_dargs: Keyword arguments passed to that method.

  Returns:
    A function taking `self` and returning whatever the named method returns.
  """
  def _Invoke(self):
    target = getattr(self, run_method_name)
    return target(**run_dargs)
  return _Invoke
1134
1135
def AddParametricTests(tested_method_name, arg_space, validate_func=None):
  """Enumerates and adds specific parametric tests to PayloadCheckerTest.

  This function enumerates a space of test parameters (defined by arg_space),
  then binds a new, unique method name in PayloadCheckerTest to a test function
  that gets handed the said parameters. This is a preferable approach to doing
  the enumeration and invocation during the tests because this way each test is
  treated as a complete run by the unittest framework, and so benefits from the
  usual setUp/tearDown mechanics.

  Each generated method is named 'test<tested_method_name>' followed by
  '__<arg>=<value>' for every argument whose value is truthy or a (non-bool)
  integer, so that e.g. indent=0 still shows up while False/None are omitted.

  Args:
    tested_method_name: Name of the tested PayloadChecker method.
    arg_space: A dictionary containing variables (keys) and lists of values
               (values) associated with them.
    validate_func: A function used for validating test argument combinations.
  """
  run_method_name = 'Do%sTest' % tested_method_name
  arg_names = list(arg_space.keys())
  for value_tuple in itertools.product(*arg_space.values()):
    run_dargs = dict(zip(arg_names, value_tuple))
    if validate_func and not validate_func(**run_dargs):
      continue
    test_method_name = 'test%s' % tested_method_name
    for arg_key, arg_val in run_dargs.items():
      # bool is a subclass of int, so it must be excluded explicitly;
      # otherwise every False value would be appended to the name too.
      if arg_val or (isinstance(arg_val, int) and
                     not isinstance(arg_val, bool)):
        test_method_name += '__%s=%s' % (arg_key, arg_val)
    setattr(PayloadCheckerTest, test_method_name,
            TestMethodBody(run_method_name, run_dargs))
1163
1164
def AddAllParametricTests():
  """Enumerates and adds all parametric tests to PayloadCheckerTest."""
  # Add all _CheckElem() test cases.
  AddParametricTests('AddElem',
                     {'linebreak': (True, False),
                      'indent': (0, 1, 2),
                      'convert': (str, lambda s: s[::-1]),
                      'is_present': (True, False),
                      'is_mandatory': (True, False),
                      'is_submsg': (True, False)})

  # Add all _Add{Mandatory,Optional}Field tests.
  AddParametricTests('AddField',
                     {'is_mandatory': (True, False),
                      'linebreak': (True, False),
                      'indent': (0, 1, 2),
                      'convert': (str, lambda s: s[::-1]),
                      'is_present': (True, False)})

  # Add all _Add{Mandatory,Optional}SubMsg tests.
  AddParametricTests('AddSubMsg',
                     {'is_mandatory': (True, False),
                      'is_present': (True, False)})

  # Add all _CheckManifest() test cases.
  AddParametricTests('CheckManifest',
                     {'fail_mismatched_block_size': (True, False),
                      'fail_bad_sigs': (True, False),
                      'fail_mismatched_oki_ori': (True, False),
                      'fail_bad_oki': (True, False),
                      'fail_bad_ori': (True, False),
                      'fail_bad_nki': (True, False),
                      'fail_bad_nri': (True, False),
                      'fail_old_kernel_fs_size': (True, False),
                      'fail_old_rootfs_fs_size': (True, False),
                      'fail_new_kernel_fs_size': (True, False),
                      'fail_new_rootfs_fs_size': (True, False)})

  # Add all _CheckOperation() test cases.
  AddParametricTests('CheckOperation',
                     {'op_type_name': ('REPLACE', 'REPLACE_BZ', 'REPLACE_XZ',
                                       'SOURCE_COPY', 'SOURCE_BSDIFF',
                                       'PUFFDIFF', 'BROTLI_BSDIFF'),
                      'allow_unhashed': (True, False),
                      'fail_src_extents': (True, False),
                      'fail_dst_extents': (True, False),
                      'fail_mismatched_data_offset_length': (True, False),
                      'fail_missing_dst_extents': (True, False),
                      'fail_src_length': (True, False),
                      'fail_dst_length': (True, False),
                      'fail_data_hash': (True, False),
                      'fail_prev_data_offset': (True, False),
                      'fail_bad_minor_version': (True, False)},
                     validate_func=ValidateCheckOperationTest)

  # Add all _CheckOperations() test cases.
  AddParametricTests('CheckOperations',
                     {'fail_nonexhaustive_full_update': (True, False)})

  # Add all _CheckSignatures() test cases.
  AddParametricTests('CheckSignatures',
                     {'fail_empty_sigs_blob': (True, False),
                      'fail_sig_missing_fields': (True, False),
                      'fail_unknown_sig_version': (True, False),
                      'fail_incorrect_sig': (True, False)})

  # Add all _CheckManifestMinorVersion() test cases.
  AddParametricTests('CheckManifestMinorVersion',
                     {'minor_version': (None, 0, 2, 3, 4, 5, 555),
                      'payload_type': (checker._TYPE_FULL,
                                       checker._TYPE_DELTA)})

  # Add all Run() test cases.
  AddParametricTests('Run',
                     {'rootfs_part_size_provided': (True, False),
                      'kernel_part_size_provided': (True, False),
                      'fail_wrong_payload_type': (True, False),
                      'fail_invalid_block_size': (True, False),
                      'fail_mismatched_metadata_size': (True, False),
                      'fail_mismatched_block_size': (True, False),
                      'fail_excess_data': (True, False),
                      'fail_rootfs_part_size_exceeded': (True, False),
                      'fail_kernel_part_size_exceeded': (True, False)})
1248
1249
if __name__ == '__main__':
  # The parametric test methods are attached to PayloadCheckerTest via
  # setattr, so they must be registered before unittest.main() collects tests.
  # NOTE(review): they are only registered when this file is run as a script.
  AddAllParametricTests()
  unittest.main()
1253