1#!/usr/bin/env python 2# 3# Copyright (C) 2013 The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16# 17 18"""Unit testing checker.py.""" 19 20# Disable check for function names to avoid errors based on old code 21# pylint: disable-msg=invalid-name 22 23from __future__ import absolute_import 24 25import array 26import collections 27import hashlib 28import io 29import itertools 30import os 31import unittest 32 33from six.moves import zip 34 35import mock # pylint: disable=import-error 36 37from update_payload import checker 38from update_payload import common 39from update_payload import test_utils 40from update_payload import update_metadata_pb2 41from update_payload.error import PayloadError 42from update_payload.payload import Payload # Avoid name conflicts later. 


def _OpTypeByName(op_name):
  """Returns the type of an operation from its name.

  Args:
    op_name: The operation name, e.g. 'REPLACE' or 'SOURCE_COPY'.

  Returns:
    The matching common.OpType value.

  Raises:
    KeyError: If op_name is not one of the names listed below.
  """
  op_name_to_type = {
      'REPLACE': common.OpType.REPLACE,
      'REPLACE_BZ': common.OpType.REPLACE_BZ,
      'SOURCE_COPY': common.OpType.SOURCE_COPY,
      'SOURCE_BSDIFF': common.OpType.SOURCE_BSDIFF,
      'ZERO': common.OpType.ZERO,
      'DISCARD': common.OpType.DISCARD,
      'REPLACE_XZ': common.OpType.REPLACE_XZ,
      'PUFFDIFF': common.OpType.PUFFDIFF,
      'BROTLI_BSDIFF': common.OpType.BROTLI_BSDIFF,
  }
  return op_name_to_type[op_name]


def _GetPayloadChecker(payload_gen_write_to_file_func, payload_gen_dargs=None,
                       checker_init_dargs=None):
  """Returns a payload checker from a given payload generator.

  The payload is serialized into an in-memory file, parsed back with Payload
  and initialized before being handed to the checker.

  Args:
    payload_gen_write_to_file_func: A callable that writes a payload into the
        file object passed as its first argument.
    payload_gen_dargs: Optional keyword arguments forwarded to
        payload_gen_write_to_file_func.
    checker_init_dargs: Optional keyword arguments forwarded to the
        checker.PayloadChecker constructor.

  Returns:
    A checker.PayloadChecker wrapping the initialized payload.
  """
  if payload_gen_dargs is None:
    payload_gen_dargs = {}
  if checker_init_dargs is None:
    checker_init_dargs = {}

  payload_file = io.BytesIO()
  payload_gen_write_to_file_func(payload_file, **payload_gen_dargs)
  # Rewind so Payload reads from the beginning of what was just written.
  payload_file.seek(0)
  payload = Payload(payload_file)
  payload.Init()
  return checker.PayloadChecker(payload, **checker_init_dargs)


def _GetPayloadCheckerWithData(payload_gen):
  """Returns a payload checker for the payload written by payload_gen.

  Args:
    payload_gen: A payload generator exposing WriteToFile(file_obj).

  Returns:
    A checker.PayloadChecker constructed with default arguments.
  """
  # Identical to _GetPayloadChecker() with no extra arguments; delegate instead
  # of duplicating the write/parse/init sequence.
  return _GetPayloadChecker(payload_gen.WriteToFile)


# This class doesn't need an __init__().
# pylint: disable=W0232
# Unit testing is all about running protected methods.
# pylint: disable=W0212
# Don't bark about missing members of classes you cannot import.
# pylint: disable=E1101
class PayloadCheckerTest(unittest.TestCase):
  """Tests the PayloadChecker class.

  In addition to ordinary testFoo() methods, which are automatically invoked by
  the unittest framework, in this class we make use of DoBarTest() calls that
  implement parametric tests of certain features. In order to invoke each test,
  which embodies a unique combination of parameter values, as a complete unit
  test, we perform explicit enumeration of the parameter space and create
  individual invocation contexts for each, which are then bound as
  testBar__param1=val1__param2=val2(). The enumeration of parameter spaces for
  all such tests is done in AddAllParametricTests().
  """

  def setUp(self):
    """Resets the list of deferred mock-call expectations."""
    self.mock_checks = []

  def tearDown(self):
    """Verifies every registered mock was called exactly once as expected."""
    # Verify that all mock functions were called.
    for check in self.mock_checks:
      check.mock_fn.assert_called_once_with(*check.exp_args,
                                            **check.exp_kwargs)

  class MockChecksAtTearDown(object):
    """Mock data storage.

    This class stores the mock functions and their arguments to be checked at
    a later point.
    """
    def __init__(self, mock_fn, *args, **kwargs):
      self.mock_fn = mock_fn
      self.exp_args = args
      self.exp_kwargs = kwargs

  def addPostCheckForMockFunction(self, mock_fn, *args, **kwargs):
    """Stores a mock function and its expected arguments to self.mock_checks.

    The expectation is verified in tearDown().

    Args:
      mock_fn: mock function object
      args: expected positional arguments for the mock_fn
      kwargs: expected named arguments for the mock_fn
    """
    self.mock_checks.append(
        self.MockChecksAtTearDown(mock_fn, *args, **kwargs))

  def MockPayload(self):
    """Create a mock payload object, complete with a mock manifest."""
    payload = mock.create_autospec(Payload)
    payload.is_init = True
    payload.manifest = mock.create_autospec(
        update_metadata_pb2.DeltaArchiveManifest)
    return payload

  @staticmethod
  def NewExtent(start_block, num_blocks):
    """Returns an Extent message.

    Each of the provided fields is set iff it is >= 0; otherwise, it's left at
    its default state.

    Args:
      start_block: The starting block of the extent.
      num_blocks: The number of blocks in the extent.

    Returns:
      An Extent message.
    """
    ex = update_metadata_pb2.Extent()
    if start_block >= 0:
      ex.start_block = start_block
    if num_blocks >= 0:
      ex.num_blocks = num_blocks
    return ex

  @staticmethod
  def NewExtentList(*args):
    """Returns a list of extents.

    Args:
      *args: (start_block, num_blocks) pairs defining the extents.

    Returns:
      A list of Extent objects.
    """
    ex_list = []
    for start_block, num_blocks in args:
      ex_list.append(PayloadCheckerTest.NewExtent(start_block, num_blocks))
    return ex_list

  @staticmethod
  def AddToMessage(repeated_field, field_vals):
    """Appends a copy of each message in field_vals to repeated_field.

    Args:
      repeated_field: A repeated message field supporting add().
      field_vals: Messages to copy into newly added elements.
    """
    for field_val in field_vals:
      new_field = repeated_field.add()
      new_field.CopyFrom(field_val)

  def SetupAddElemTest(self, is_present, is_submsg, convert=str,
                       linebreak=False, indent=0):
    """Setup for testing of _CheckElem() and its derivatives.

    Args:
      is_present: Whether or not the element is found in the message.
      is_submsg: Whether the element is a sub-message itself.
      convert: A representation conversion function.
      linebreak: Whether or not a linebreak is to be used in the report.
      indent: Indentation used for the report.

    Returns:
      msg: A mock message object.
      report: A mock report object.
      subreport: A mock sub-report object.
      name: An element name to check.
      val: Expected element value.
    """
    name = 'foo'
    val = 'fake submsg' if is_submsg else 'fake field'
    subreport = 'fake subreport'

    # Create a mock message.
    msg = mock.create_autospec(update_metadata_pb2._message.Message)
    self.addPostCheckForMockFunction(msg.HasField, name)
    msg.HasField.return_value = is_present
    setattr(msg, name, val)
    # Create a mock report.
    report = mock.create_autospec(checker._PayloadReport)
    if is_present:
      if is_submsg:
        self.addPostCheckForMockFunction(report.AddSubReport, name)
        report.AddSubReport.return_value = subreport
      else:
        self.addPostCheckForMockFunction(report.AddField, name, convert(val),
                                         linebreak=linebreak, indent=indent)

    return (msg, report, subreport, name, val)

  def DoAddElemTest(self, is_present, is_mandatory, is_submsg, convert,
                    linebreak, indent):
    """Parametric testing of _CheckElem().

    Args:
      is_present: Whether or not the element is found in the message.
      is_mandatory: Whether or not it's a mandatory element.
      is_submsg: Whether the element is a sub-message itself.
      convert: A representation conversion function.
      linebreak: Whether or not a linebreak is to be used in the report.
      indent: Indentation used for the report.
    """
    msg, report, subreport, name, val = self.SetupAddElemTest(
        is_present, is_submsg, convert, linebreak, indent)

    args = (msg, name, report, is_mandatory, is_submsg)
    kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
    if is_mandatory and not is_present:
      self.assertRaises(PayloadError,
                        checker.PayloadChecker._CheckElem, *args, **kwargs)
    else:
      ret_val, ret_subreport = checker.PayloadChecker._CheckElem(*args,
                                                                 **kwargs)
      self.assertEqual(val if is_present else None, ret_val)
      self.assertEqual(subreport if is_present and is_submsg else None,
                       ret_subreport)

  def DoAddFieldTest(self, is_mandatory, is_present, convert, linebreak,
                     indent):
    """Parametric testing of _Check{Mandatory,Optional}Field().

    Args:
      is_mandatory: Whether we're testing a mandatory call.
      is_present: Whether or not the element is found in the message.
      convert: A representation conversion function.
      linebreak: Whether or not a linebreak is to be used in the report.
      indent: Indentation used for the report.
    """
    msg, report, _, name, val = self.SetupAddElemTest(
        is_present, False, convert, linebreak, indent)

    # Prepare for invocation of the tested method.
    args = [msg, name, report]
    kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
    if is_mandatory:
      args.append('bar')
      tested_func = checker.PayloadChecker._CheckMandatoryField
    else:
      tested_func = checker.PayloadChecker._CheckOptionalField

    # Test the method call.
    if is_mandatory and not is_present:
      self.assertRaises(PayloadError, tested_func, *args, **kwargs)
    else:
      ret_val = tested_func(*args, **kwargs)
      self.assertEqual(val if is_present else None, ret_val)

  def DoAddSubMsgTest(self, is_mandatory, is_present):
    """Parametrized testing of _Check{Mandatory,Optional}SubMsg().

    Args:
      is_mandatory: Whether we're testing a mandatory call.
      is_present: Whether or not the element is found in the message.
    """
    msg, report, subreport, name, val = self.SetupAddElemTest(is_present, True)

    # Prepare for invocation of the tested method.
    args = [msg, name, report]
    if is_mandatory:
      args.append('bar')
      tested_func = checker.PayloadChecker._CheckMandatorySubMsg
    else:
      tested_func = checker.PayloadChecker._CheckOptionalSubMsg

    # Test the method call.
    if is_mandatory and not is_present:
      self.assertRaises(PayloadError, tested_func, *args)
    else:
      ret_val, ret_subreport = tested_func(*args)
      self.assertEqual(val if is_present else None, ret_val)
      self.assertEqual(subreport if is_present else None, ret_subreport)

  def testCheckPresentIff(self):
    """Tests _CheckPresentIff()."""
    self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
        None, None, 'foo', 'bar', 'baz'))
    self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
        'a', 'b', 'foo', 'bar', 'baz'))
    self.assertRaises(PayloadError, checker.PayloadChecker._CheckPresentIff,
                      'a', None, 'foo', 'bar', 'baz')
    self.assertRaises(PayloadError, checker.PayloadChecker._CheckPresentIff,
                      None, 'b', 'foo', 'bar', 'baz')

  def DoCheckSha256SignatureTest(self, expect_pass, expect_subprocess_call,
                                 sig_data, sig_asn1_header,
                                 returned_signed_hash, expected_signed_hash):
    """Parametric testing of _CheckSha256Signature().

    Args:
      expect_pass: Whether or not it should pass.
      expect_subprocess_call: Whether to expect the openssl call to happen.
      sig_data: The signature raw data.
      sig_asn1_header: The ASN1 header.
      returned_signed_hash: The signed hash data returned by openssl.
      expected_signed_hash: The signed hash data to compare against.
    """
    # Stub out the subprocess invocation.
    with mock.patch.object(checker.PayloadChecker, '_Run') as mock_run:
      if expect_subprocess_call:
        mock_run.return_value = (sig_asn1_header + returned_signed_hash, None)

      if expect_pass:
        self.assertIsNone(checker.PayloadChecker._CheckSha256Signature(
            sig_data, 'foo', expected_signed_hash, 'bar'))
      else:
        self.assertRaises(PayloadError,
                          checker.PayloadChecker._CheckSha256Signature,
                          sig_data, 'foo', expected_signed_hash, 'bar')

      # Verify the stubbed subprocess call actually happened (fed with the raw
      # signature data) or was skipped, as the caller expects. The command
      # line itself is an implementation detail, hence mock.ANY.
      if expect_subprocess_call:
        mock_run.assert_called_once_with(mock.ANY, send_data=sig_data)
      else:
        mock_run.assert_not_called()

  def testCheckSha256Signature_Pass(self):
    """Tests _CheckSha256Signature(); pass case."""
    sig_data = 'fake-signature'.ljust(256)
    signed_hash = hashlib.sha256(b'fake-data').digest()
    self.DoCheckSha256SignatureTest(True, True, sig_data,
                                    common.SIG_ASN1_HEADER, signed_hash,
                                    signed_hash)

  def testCheckSha256Signature_FailBadSignature(self):
    """Tests _CheckSha256Signature(); fails due to malformed signature."""
    sig_data = 'fake-signature'  # Malformed (not 256 bytes in length).
    signed_hash = hashlib.sha256(b'fake-data').digest()
    self.DoCheckSha256SignatureTest(False, False, sig_data,
                                    common.SIG_ASN1_HEADER, signed_hash,
                                    signed_hash)

  def testCheckSha256Signature_FailBadOutputLength(self):
    """Tests _CheckSha256Signature(); fails due to unexpected output length."""
    sig_data = 'fake-signature'.ljust(256)
    signed_hash = b'fake-hash'  # Malformed (not 32 bytes in length).
    self.DoCheckSha256SignatureTest(False, True, sig_data,
                                    common.SIG_ASN1_HEADER, signed_hash,
                                    signed_hash)

  def testCheckSha256Signature_FailBadAsnHeader(self):
    """Tests _CheckSha256Signature(); fails due to bad ASN1 header."""
    sig_data = 'fake-signature'.ljust(256)
    signed_hash = hashlib.sha256(b'fake-data').digest()
    bad_asn1_header = b'bad-asn-header'.ljust(len(common.SIG_ASN1_HEADER))
    self.DoCheckSha256SignatureTest(False, True, sig_data, bad_asn1_header,
                                    signed_hash, signed_hash)

  def testCheckSha256Signature_FailBadHash(self):
    """Tests _CheckSha256Signature(); fails due to bad hash returned."""
    sig_data = 'fake-signature'.ljust(256)
    expected_signed_hash = hashlib.sha256(b'fake-data').digest()
    returned_signed_hash = hashlib.sha256(b'bad-fake-data').digest()
    # Pass the hashes in the (returned, expected) order the helper declares.
    self.DoCheckSha256SignatureTest(False, True, sig_data,
                                    common.SIG_ASN1_HEADER,
                                    returned_signed_hash, expected_signed_hash)

  def testCheckBlocksFitLength_Pass(self):
    """Tests _CheckBlocksFitLength(); pass case."""
    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
        64, 4, 16, 'foo'))
    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
        60, 4, 16, 'foo'))
    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
        49, 4, 16, 'foo'))
    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
        48, 3, 16, 'foo'))

  def testCheckBlocksFitLength_TooManyBlocks(self):
    """Tests _CheckBlocksFitLength(); fails due to excess blocks."""
    self.assertRaises(PayloadError,
                      checker.PayloadChecker._CheckBlocksFitLength,
                      64, 5, 16, 'foo')
    self.assertRaises(PayloadError,
                      checker.PayloadChecker._CheckBlocksFitLength,
                      60, 5, 16, 'foo')
    self.assertRaises(PayloadError,
                      checker.PayloadChecker._CheckBlocksFitLength,
                      49, 5, 16, 'foo')
    self.assertRaises(PayloadError,
                      checker.PayloadChecker._CheckBlocksFitLength,
                      48, 4, 16, 'foo')

415 def testCheckBlocksFitLength_TooFewBlocks(self): 416 """Tests _CheckBlocksFitLength(); fails due to insufficient blocks.""" 417 self.assertRaises(PayloadError, 418 checker.PayloadChecker._CheckBlocksFitLength, 419 64, 3, 16, 'foo') 420 self.assertRaises(PayloadError, 421 checker.PayloadChecker._CheckBlocksFitLength, 422 60, 3, 16, 'foo') 423 self.assertRaises(PayloadError, 424 checker.PayloadChecker._CheckBlocksFitLength, 425 49, 3, 16, 'foo') 426 self.assertRaises(PayloadError, 427 checker.PayloadChecker._CheckBlocksFitLength, 428 48, 2, 16, 'foo') 429 430 def DoCheckManifestTest(self, fail_mismatched_block_size, fail_bad_sigs, 431 fail_mismatched_oki_ori, fail_bad_oki, fail_bad_ori, 432 fail_bad_nki, fail_bad_nri, fail_old_kernel_fs_size, 433 fail_old_rootfs_fs_size, fail_new_kernel_fs_size, 434 fail_new_rootfs_fs_size): 435 """Parametric testing of _CheckManifest(). 436 437 Args: 438 fail_mismatched_block_size: Simulate a missing block_size field. 439 fail_bad_sigs: Make signatures descriptor inconsistent. 440 fail_mismatched_oki_ori: Make old rootfs/kernel info partially present. 441 fail_bad_oki: Tamper with old kernel info. 442 fail_bad_ori: Tamper with old rootfs info. 443 fail_bad_nki: Tamper with new kernel info. 444 fail_bad_nri: Tamper with new rootfs info. 445 fail_old_kernel_fs_size: Make old kernel fs size too big. 446 fail_old_rootfs_fs_size: Make old rootfs fs size too big. 447 fail_new_kernel_fs_size: Make new kernel fs size too big. 448 fail_new_rootfs_fs_size: Make new rootfs fs size too big. 449 """ 450 # Generate a test payload. For this test, we only care about the manifest 451 # and don't need any data blobs, hence we can use a plain paylaod generator 452 # (which also gives us more control on things that can be screwed up). 453 payload_gen = test_utils.PayloadGenerator() 454 455 # Tamper with block size, if required. 
456 if fail_mismatched_block_size: 457 payload_gen.SetBlockSize(test_utils.KiB(1)) 458 else: 459 payload_gen.SetBlockSize(test_utils.KiB(4)) 460 461 # Add some operations. 462 payload_gen.AddOperation(common.ROOTFS, common.OpType.SOURCE_COPY, 463 src_extents=[(0, 16), (16, 497)], 464 dst_extents=[(16, 496), (0, 16)]) 465 payload_gen.AddOperation(common.KERNEL, common.OpType.SOURCE_COPY, 466 src_extents=[(0, 8), (8, 8)], 467 dst_extents=[(8, 8), (0, 8)]) 468 469 # Set an invalid signatures block (offset but no size), if required. 470 if fail_bad_sigs: 471 payload_gen.SetSignatures(32, None) 472 473 # Set partition / filesystem sizes. 474 rootfs_part_size = test_utils.MiB(8) 475 kernel_part_size = test_utils.KiB(512) 476 old_rootfs_fs_size = new_rootfs_fs_size = rootfs_part_size 477 old_kernel_fs_size = new_kernel_fs_size = kernel_part_size 478 if fail_old_kernel_fs_size: 479 old_kernel_fs_size += 100 480 if fail_old_rootfs_fs_size: 481 old_rootfs_fs_size += 100 482 if fail_new_kernel_fs_size: 483 new_kernel_fs_size += 100 484 if fail_new_rootfs_fs_size: 485 new_rootfs_fs_size += 100 486 487 # Add old kernel/rootfs partition info, as required. 488 if fail_mismatched_oki_ori or fail_old_kernel_fs_size or fail_bad_oki: 489 oki_hash = (None if fail_bad_oki 490 else hashlib.sha256(b'fake-oki-content').digest()) 491 payload_gen.SetPartInfo(common.KERNEL, False, old_kernel_fs_size, 492 oki_hash) 493 if not fail_mismatched_oki_ori and (fail_old_rootfs_fs_size or 494 fail_bad_ori): 495 ori_hash = (None if fail_bad_ori 496 else hashlib.sha256(b'fake-ori-content').digest()) 497 payload_gen.SetPartInfo(common.ROOTFS, False, old_rootfs_fs_size, 498 ori_hash) 499 500 # Add new kernel/rootfs partition info. 
501 payload_gen.SetPartInfo( 502 common.KERNEL, True, new_kernel_fs_size, 503 None if fail_bad_nki else hashlib.sha256(b'fake-nki-content').digest()) 504 payload_gen.SetPartInfo( 505 common.ROOTFS, True, new_rootfs_fs_size, 506 None if fail_bad_nri else hashlib.sha256(b'fake-nri-content').digest()) 507 508 # Set the minor version. 509 payload_gen.SetMinorVersion(0) 510 511 # Create the test object. 512 payload_checker = _GetPayloadChecker(payload_gen.WriteToFile) 513 report = checker._PayloadReport() 514 515 should_fail = (fail_mismatched_block_size or fail_bad_sigs or 516 fail_mismatched_oki_ori or fail_bad_oki or fail_bad_ori or 517 fail_bad_nki or fail_bad_nri or fail_old_kernel_fs_size or 518 fail_old_rootfs_fs_size or fail_new_kernel_fs_size or 519 fail_new_rootfs_fs_size) 520 part_sizes = { 521 common.ROOTFS: rootfs_part_size, 522 common.KERNEL: kernel_part_size 523 } 524 525 if should_fail: 526 self.assertRaises(PayloadError, payload_checker._CheckManifest, report, 527 part_sizes) 528 else: 529 self.assertIsNone(payload_checker._CheckManifest(report, part_sizes)) 530 531 def testCheckLength(self): 532 """Tests _CheckLength().""" 533 payload_checker = checker.PayloadChecker(self.MockPayload()) 534 block_size = payload_checker.block_size 535 536 # Passes. 537 self.assertIsNone(payload_checker._CheckLength( 538 int(3.5 * block_size), 4, 'foo', 'bar')) 539 # Fails, too few blocks. 540 self.assertRaises(PayloadError, payload_checker._CheckLength, 541 int(3.5 * block_size), 3, 'foo', 'bar') 542 # Fails, too many blocks. 543 self.assertRaises(PayloadError, payload_checker._CheckLength, 544 int(3.5 * block_size), 5, 'foo', 'bar') 545 546 def testCheckExtents(self): 547 """Tests _CheckExtents().""" 548 payload_checker = checker.PayloadChecker(self.MockPayload()) 549 block_size = payload_checker.block_size 550 551 # Passes w/ all real extents. 
552 extents = self.NewExtentList((0, 4), (8, 3), (1024, 16)) 553 self.assertEqual( 554 23, 555 payload_checker._CheckExtents(extents, (1024 + 16) * block_size, 556 collections.defaultdict(int), 'foo')) 557 558 # Fails, extent missing a start block. 559 extents = self.NewExtentList((-1, 4), (8, 3), (1024, 16)) 560 self.assertRaises( 561 PayloadError, payload_checker._CheckExtents, extents, 562 (1024 + 16) * block_size, collections.defaultdict(int), 'foo') 563 564 # Fails, extent missing block count. 565 extents = self.NewExtentList((0, -1), (8, 3), (1024, 16)) 566 self.assertRaises( 567 PayloadError, payload_checker._CheckExtents, extents, 568 (1024 + 16) * block_size, collections.defaultdict(int), 'foo') 569 570 # Fails, extent has zero blocks. 571 extents = self.NewExtentList((0, 4), (8, 3), (1024, 0)) 572 self.assertRaises( 573 PayloadError, payload_checker._CheckExtents, extents, 574 (1024 + 16) * block_size, collections.defaultdict(int), 'foo') 575 576 # Fails, extent exceeds partition boundaries. 577 extents = self.NewExtentList((0, 4), (8, 3), (1024, 16)) 578 self.assertRaises( 579 PayloadError, payload_checker._CheckExtents, extents, 580 (1024 + 15) * block_size, collections.defaultdict(int), 'foo') 581 582 def testCheckReplaceOperation(self): 583 """Tests _CheckReplaceOperation() where op.type == REPLACE.""" 584 payload_checker = checker.PayloadChecker(self.MockPayload()) 585 block_size = payload_checker.block_size 586 data_length = 10000 587 588 op = mock.create_autospec(update_metadata_pb2.InstallOperation) 589 op.type = common.OpType.REPLACE 590 591 # Pass. 592 op.src_extents = [] 593 self.assertIsNone( 594 payload_checker._CheckReplaceOperation( 595 op, data_length, (data_length + block_size - 1) // block_size, 596 'foo')) 597 598 # Fail, src extents founds. 
599 op.src_extents = ['bar'] 600 self.assertRaises( 601 PayloadError, payload_checker._CheckReplaceOperation, 602 op, data_length, (data_length + block_size - 1) // block_size, 'foo') 603 604 # Fail, missing data. 605 op.src_extents = [] 606 self.assertRaises( 607 PayloadError, payload_checker._CheckReplaceOperation, 608 op, None, (data_length + block_size - 1) // block_size, 'foo') 609 610 # Fail, length / block number mismatch. 611 op.src_extents = ['bar'] 612 self.assertRaises( 613 PayloadError, payload_checker._CheckReplaceOperation, 614 op, data_length, (data_length + block_size - 1) // block_size + 1, 615 'foo') 616 617 def testCheckReplaceBzOperation(self): 618 """Tests _CheckReplaceOperation() where op.type == REPLACE_BZ.""" 619 payload_checker = checker.PayloadChecker(self.MockPayload()) 620 block_size = payload_checker.block_size 621 data_length = block_size * 3 622 623 op = mock.create_autospec( 624 update_metadata_pb2.InstallOperation) 625 op.type = common.OpType.REPLACE_BZ 626 627 # Pass. 628 op.src_extents = [] 629 self.assertIsNone( 630 payload_checker._CheckReplaceOperation( 631 op, data_length, (data_length + block_size - 1) // block_size + 5, 632 'foo')) 633 634 # Fail, src extents founds. 635 op.src_extents = ['bar'] 636 self.assertRaises( 637 PayloadError, payload_checker._CheckReplaceOperation, 638 op, data_length, (data_length + block_size - 1) // block_size + 5, 639 'foo') 640 641 # Fail, missing data. 642 op.src_extents = [] 643 self.assertRaises( 644 PayloadError, payload_checker._CheckReplaceOperation, 645 op, None, (data_length + block_size - 1) // block_size, 'foo') 646 647 # Fail, too few blocks to justify BZ. 648 op.src_extents = [] 649 self.assertRaises( 650 PayloadError, payload_checker._CheckReplaceOperation, 651 op, data_length, (data_length + block_size - 1) // block_size, 'foo') 652 653 # Fail, total_dst_blocks is a floating point value. 
654 op.src_extents = [] 655 self.assertRaises( 656 PayloadError, payload_checker._CheckReplaceOperation, 657 op, data_length, (data_length + block_size - 1) / block_size, 'foo') 658 659 def testCheckReplaceXzOperation(self): 660 """Tests _CheckReplaceOperation() where op.type == REPLACE_XZ.""" 661 payload_checker = checker.PayloadChecker(self.MockPayload()) 662 block_size = payload_checker.block_size 663 data_length = block_size * 3 664 665 op = mock.create_autospec( 666 update_metadata_pb2.InstallOperation) 667 op.type = common.OpType.REPLACE_XZ 668 669 # Pass. 670 op.src_extents = [] 671 self.assertIsNone( 672 payload_checker._CheckReplaceOperation( 673 op, data_length, (data_length + block_size - 1) // block_size + 5, 674 'foo')) 675 676 # Fail, src extents founds. 677 op.src_extents = ['bar'] 678 self.assertRaises( 679 PayloadError, payload_checker._CheckReplaceOperation, 680 op, data_length, (data_length + block_size - 1) // block_size + 5, 681 'foo') 682 683 # Fail, missing data. 684 op.src_extents = [] 685 self.assertRaises( 686 PayloadError, payload_checker._CheckReplaceOperation, 687 op, None, (data_length + block_size - 1) // block_size, 'foo') 688 689 # Fail, too few blocks to justify XZ. 690 op.src_extents = [] 691 self.assertRaises( 692 PayloadError, payload_checker._CheckReplaceOperation, 693 op, data_length, (data_length + block_size - 1) // block_size, 'foo') 694 695 # Fail, total_dst_blocks is a floating point value. 696 op.src_extents = [] 697 self.assertRaises( 698 PayloadError, payload_checker._CheckReplaceOperation, 699 op, data_length, (data_length + block_size - 1) / block_size, 'foo') 700 701 def testCheckAnyDiff(self): 702 """Tests _CheckAnyDiffOperation().""" 703 payload_checker = checker.PayloadChecker(self.MockPayload()) 704 op = update_metadata_pb2.InstallOperation() 705 706 # Pass. 707 self.assertIsNone( 708 payload_checker._CheckAnyDiffOperation(op, 10000, 3, 'foo')) 709 710 # Fail, missing data blob. 
711 self.assertRaises( 712 PayloadError, payload_checker._CheckAnyDiffOperation, 713 op, None, 3, 'foo') 714 715 # Fail, too big of a diff blob (unjustified). 716 self.assertRaises( 717 PayloadError, payload_checker._CheckAnyDiffOperation, 718 op, 10000, 2, 'foo') 719 720 def testCheckSourceCopyOperation_Pass(self): 721 """Tests _CheckSourceCopyOperation(); pass case.""" 722 payload_checker = checker.PayloadChecker(self.MockPayload()) 723 self.assertIsNone( 724 payload_checker._CheckSourceCopyOperation(None, 134, 134, 'foo')) 725 726 def testCheckSourceCopyOperation_FailContainsData(self): 727 """Tests _CheckSourceCopyOperation(); message contains data.""" 728 payload_checker = checker.PayloadChecker(self.MockPayload()) 729 self.assertRaises(PayloadError, payload_checker._CheckSourceCopyOperation, 730 134, 0, 0, 'foo') 731 732 def testCheckSourceCopyOperation_FailBlockCountsMismatch(self): 733 """Tests _CheckSourceCopyOperation(); src and dst block totals not equal.""" 734 payload_checker = checker.PayloadChecker(self.MockPayload()) 735 self.assertRaises(PayloadError, payload_checker._CheckSourceCopyOperation, 736 None, 0, 1, 'foo') 737 738 def DoCheckOperationTest(self, op_type_name, allow_unhashed, 739 fail_src_extents, fail_dst_extents, 740 fail_mismatched_data_offset_length, 741 fail_missing_dst_extents, fail_src_length, 742 fail_dst_length, fail_data_hash, 743 fail_prev_data_offset, fail_bad_minor_version): 744 """Parametric testing of _CheckOperation(). 745 746 Args: 747 op_type_name: 'REPLACE', 'REPLACE_BZ', 'REPLACE_XZ', 748 'SOURCE_COPY', 'SOURCE_BSDIFF', BROTLI_BSDIFF or 'PUFFDIFF'. 749 allow_unhashed: Whether we're allowing to not hash the data. 750 fail_src_extents: Tamper with src extents. 751 fail_dst_extents: Tamper with dst extents. 752 fail_mismatched_data_offset_length: Make data_{offset,length} 753 inconsistent. 754 fail_missing_dst_extents: Do not include dst extents. 755 fail_src_length: Make src length inconsistent. 
756 fail_dst_length: Make dst length inconsistent. 757 fail_data_hash: Tamper with the data blob hash. 758 fail_prev_data_offset: Make data space uses incontiguous. 759 fail_bad_minor_version: Make minor version incompatible with op. 760 """ 761 op_type = _OpTypeByName(op_type_name) 762 763 # Create the test object. 764 payload = self.MockPayload() 765 payload_checker = checker.PayloadChecker(payload, 766 allow_unhashed=allow_unhashed) 767 block_size = payload_checker.block_size 768 769 # Create auxiliary arguments. 770 old_part_size = test_utils.MiB(4) 771 new_part_size = test_utils.MiB(8) 772 old_block_counters = array.array( 773 'B', [0] * ((old_part_size + block_size - 1) // block_size)) 774 new_block_counters = array.array( 775 'B', [0] * ((new_part_size + block_size - 1) // block_size)) 776 prev_data_offset = 1876 777 blob_hash_counts = collections.defaultdict(int) 778 779 # Create the operation object for the test. 780 op = update_metadata_pb2.InstallOperation() 781 op.type = op_type 782 783 total_src_blocks = 0 784 if op_type in (common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF, 785 common.OpType.PUFFDIFF, common.OpType.BROTLI_BSDIFF): 786 if fail_src_extents: 787 self.AddToMessage(op.src_extents, 788 self.NewExtentList((1, 0))) 789 else: 790 self.AddToMessage(op.src_extents, 791 self.NewExtentList((1, 16))) 792 total_src_blocks = 16 793 794 payload_checker.major_version = common.BRILLO_MAJOR_PAYLOAD_VERSION 795 if op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ): 796 payload_checker.minor_version = 0 797 elif op_type in (common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF): 798 payload_checker.minor_version = 1 if fail_bad_minor_version else 2 799 if op_type == common.OpType.REPLACE_XZ: 800 payload_checker.minor_version = 2 if fail_bad_minor_version else 3 801 elif op_type in (common.OpType.ZERO, common.OpType.DISCARD, 802 common.OpType.BROTLI_BSDIFF): 803 payload_checker.minor_version = 3 if fail_bad_minor_version else 4 804 elif 
op_type == common.OpType.PUFFDIFF: 805 payload_checker.minor_version = 4 if fail_bad_minor_version else 5 806 807 if op_type != common.OpType.SOURCE_COPY: 808 if not fail_mismatched_data_offset_length: 809 op.data_length = 16 * block_size - 8 810 if fail_prev_data_offset: 811 op.data_offset = prev_data_offset + 16 812 else: 813 op.data_offset = prev_data_offset 814 815 fake_data = 'fake-data'.ljust(op.data_length) 816 if not allow_unhashed and not fail_data_hash: 817 # Create a valid data blob hash. 818 op.data_sha256_hash = hashlib.sha256(fake_data.encode('utf-8')).digest() 819 payload.ReadDataBlob.return_value = fake_data.encode('utf-8') 820 821 elif fail_data_hash: 822 # Create an invalid data blob hash. 823 op.data_sha256_hash = hashlib.sha256( 824 fake_data.replace(' ', '-').encode('utf-8')).digest() 825 payload.ReadDataBlob.return_value = fake_data.encode('utf-8') 826 827 total_dst_blocks = 0 828 if not fail_missing_dst_extents: 829 total_dst_blocks = 16 830 if fail_dst_extents: 831 self.AddToMessage(op.dst_extents, 832 self.NewExtentList((4, 16), (32, 0))) 833 else: 834 self.AddToMessage(op.dst_extents, 835 self.NewExtentList((4, 8), (64, 8))) 836 837 if total_src_blocks: 838 if fail_src_length: 839 op.src_length = total_src_blocks * block_size + 8 840 elif (op_type == common.OpType.SOURCE_BSDIFF and 841 payload_checker.minor_version <= 3): 842 op.src_length = total_src_blocks * block_size 843 elif fail_src_length: 844 # Add an orphaned src_length. 
845 op.src_length = 16 846 847 if total_dst_blocks: 848 if fail_dst_length: 849 op.dst_length = total_dst_blocks * block_size + 8 850 elif (op_type == common.OpType.SOURCE_BSDIFF and 851 payload_checker.minor_version <= 3): 852 op.dst_length = total_dst_blocks * block_size 853 854 should_fail = (fail_src_extents or fail_dst_extents or 855 fail_mismatched_data_offset_length or 856 fail_missing_dst_extents or fail_src_length or 857 fail_dst_length or fail_data_hash or fail_prev_data_offset or 858 fail_bad_minor_version) 859 args = (op, 'foo', old_block_counters, new_block_counters, 860 old_part_size, new_part_size, prev_data_offset, 861 blob_hash_counts) 862 if should_fail: 863 self.assertRaises(PayloadError, payload_checker._CheckOperation, *args) 864 else: 865 self.assertEqual(op.data_length if op.HasField('data_length') else 0, 866 payload_checker._CheckOperation(*args)) 867 868 def testAllocBlockCounters(self): 869 """Tests _CheckMoveOperation().""" 870 payload_checker = checker.PayloadChecker(self.MockPayload()) 871 block_size = payload_checker.block_size 872 873 # Check allocation for block-aligned partition size, ensure it's integers. 874 result = payload_checker._AllocBlockCounters(16 * block_size) 875 self.assertEqual(16, len(result)) 876 self.assertEqual(int, type(result[0])) 877 878 # Check allocation of unaligned partition sizes. 879 result = payload_checker._AllocBlockCounters(16 * block_size - 1) 880 self.assertEqual(16, len(result)) 881 result = payload_checker._AllocBlockCounters(16 * block_size + 1) 882 self.assertEqual(17, len(result)) 883 884 def DoCheckOperationsTest(self, fail_nonexhaustive_full_update): 885 """Tests _CheckOperations().""" 886 # Generate a test payload. For this test, we only care about one 887 # (arbitrary) set of operations, so we'll only be generating kernel and 888 # test with them. 
889 payload_gen = test_utils.PayloadGenerator() 890 891 block_size = test_utils.KiB(4) 892 payload_gen.SetBlockSize(block_size) 893 894 rootfs_part_size = test_utils.MiB(8) 895 896 # Fake rootfs operations in a full update, tampered with as required. 897 rootfs_op_type = common.OpType.REPLACE 898 rootfs_data_length = rootfs_part_size 899 if fail_nonexhaustive_full_update: 900 rootfs_data_length -= block_size 901 902 payload_gen.AddOperation(common.ROOTFS, rootfs_op_type, 903 dst_extents= 904 [(0, rootfs_data_length // block_size)], 905 data_offset=0, 906 data_length=rootfs_data_length) 907 908 # Create the test object. 909 payload_checker = _GetPayloadChecker(payload_gen.WriteToFile, 910 checker_init_dargs={ 911 'allow_unhashed': True}) 912 payload_checker.payload_type = checker._TYPE_FULL 913 report = checker._PayloadReport() 914 partition = next((p for p in payload_checker.payload.manifest.partitions 915 if p.partition_name == common.ROOTFS), None) 916 args = (partition.operations, report, 'foo', 917 0, rootfs_part_size, rootfs_part_size, rootfs_part_size, 0) 918 if fail_nonexhaustive_full_update: 919 self.assertRaises(PayloadError, payload_checker._CheckOperations, *args) 920 else: 921 self.assertEqual(rootfs_data_length, 922 payload_checker._CheckOperations(*args)) 923 924 def DoCheckSignaturesTest(self, fail_empty_sigs_blob, fail_sig_missing_fields, 925 fail_unknown_sig_version, fail_incorrect_sig): 926 """Tests _CheckSignatures().""" 927 # Generate a test payload. For this test, we only care about the signature 928 # block and how it relates to the payload hash. Therefore, we're generating 929 # a random (otherwise useless) payload for this purpose. 
930 payload_gen = test_utils.EnhancedPayloadGenerator() 931 block_size = test_utils.KiB(4) 932 payload_gen.SetBlockSize(block_size) 933 rootfs_part_size = test_utils.MiB(2) 934 kernel_part_size = test_utils.KiB(16) 935 payload_gen.SetPartInfo(common.ROOTFS, True, rootfs_part_size, 936 hashlib.sha256(b'fake-new-rootfs-content').digest()) 937 payload_gen.SetPartInfo(common.KERNEL, True, kernel_part_size, 938 hashlib.sha256(b'fake-new-kernel-content').digest()) 939 payload_gen.SetMinorVersion(0) 940 payload_gen.AddOperationWithData( 941 common.ROOTFS, common.OpType.REPLACE, 942 dst_extents=[(0, rootfs_part_size // block_size)], 943 data_blob=os.urandom(rootfs_part_size)) 944 945 do_forge_sigs_data = (fail_empty_sigs_blob or fail_sig_missing_fields or 946 fail_unknown_sig_version or fail_incorrect_sig) 947 948 sigs_data = None 949 if do_forge_sigs_data: 950 sigs_gen = test_utils.SignaturesGenerator() 951 if not fail_empty_sigs_blob: 952 if fail_sig_missing_fields: 953 sig_data = None 954 else: 955 sig_data = test_utils.SignSha256(b'fake-payload-content', 956 test_utils._PRIVKEY_FILE_NAME) 957 sigs_gen.AddSig(5 if fail_unknown_sig_version else 1, sig_data) 958 959 sigs_data = sigs_gen.ToBinary() 960 payload_gen.SetSignatures(payload_gen.curr_offset, len(sigs_data)) 961 962 # Generate payload (complete w/ signature) and create the test object. 963 payload_checker = _GetPayloadChecker( 964 payload_gen.WriteToFileWithData, 965 payload_gen_dargs={ 966 'sigs_data': sigs_data, 967 'privkey_file_name': test_utils._PRIVKEY_FILE_NAME}) 968 payload_checker.payload_type = checker._TYPE_FULL 969 report = checker._PayloadReport() 970 971 # We have to check the manifest first in order to set signature attributes. 
972 payload_checker._CheckManifest(report, { 973 common.ROOTFS: rootfs_part_size, 974 common.KERNEL: kernel_part_size 975 }) 976 977 should_fail = (fail_empty_sigs_blob or fail_sig_missing_fields or 978 fail_unknown_sig_version or fail_incorrect_sig) 979 args = (report, test_utils._PUBKEY_FILE_NAME) 980 if should_fail: 981 self.assertRaises(PayloadError, payload_checker._CheckSignatures, *args) 982 else: 983 self.assertIsNone(payload_checker._CheckSignatures(*args)) 984 985 def DoCheckManifestMinorVersionTest(self, minor_version, payload_type): 986 """Parametric testing for CheckManifestMinorVersion(). 987 988 Args: 989 minor_version: The payload minor version to test with. 990 payload_type: The type of the payload we're testing, delta or full. 991 """ 992 # Create the test object. 993 payload = self.MockPayload() 994 payload.manifest.minor_version = minor_version 995 payload_checker = checker.PayloadChecker(payload) 996 payload_checker.payload_type = payload_type 997 report = checker._PayloadReport() 998 999 should_succeed = ( 1000 (minor_version == 0 and payload_type == checker._TYPE_FULL) or 1001 (minor_version == 2 and payload_type == checker._TYPE_DELTA) or 1002 (minor_version == 3 and payload_type == checker._TYPE_DELTA) or 1003 (minor_version == 4 and payload_type == checker._TYPE_DELTA) or 1004 (minor_version == 5 and payload_type == checker._TYPE_DELTA)) 1005 args = (report,) 1006 1007 if should_succeed: 1008 self.assertIsNone(payload_checker._CheckManifestMinorVersion(*args)) 1009 else: 1010 self.assertRaises(PayloadError, 1011 payload_checker._CheckManifestMinorVersion, *args) 1012 1013 def DoRunTest(self, rootfs_part_size_provided, kernel_part_size_provided, 1014 fail_wrong_payload_type, fail_invalid_block_size, 1015 fail_mismatched_metadata_size, fail_mismatched_block_size, 1016 fail_excess_data, fail_rootfs_part_size_exceeded, 1017 fail_kernel_part_size_exceeded): 1018 """Tests Run().""" 1019 # Generate a test payload. 
For this test, we generate a full update that 1020 # has sample kernel and rootfs operations. Since most testing is done with 1021 # internal PayloadChecker methods that are tested elsewhere, here we only 1022 # tamper with what's actually being manipulated and/or tested in the Run() 1023 # method itself. Note that the checker doesn't verify partition hashes, so 1024 # they're safe to fake. 1025 payload_gen = test_utils.EnhancedPayloadGenerator() 1026 block_size = test_utils.KiB(4) 1027 payload_gen.SetBlockSize(block_size) 1028 kernel_filesystem_size = test_utils.KiB(16) 1029 rootfs_filesystem_size = test_utils.MiB(2) 1030 payload_gen.SetPartInfo(common.ROOTFS, True, rootfs_filesystem_size, 1031 hashlib.sha256(b'fake-new-rootfs-content').digest()) 1032 payload_gen.SetPartInfo(common.KERNEL, True, kernel_filesystem_size, 1033 hashlib.sha256(b'fake-new-kernel-content').digest()) 1034 payload_gen.SetMinorVersion(0) 1035 1036 rootfs_part_size = 0 1037 if rootfs_part_size_provided: 1038 rootfs_part_size = rootfs_filesystem_size + block_size 1039 rootfs_op_size = rootfs_part_size or rootfs_filesystem_size 1040 if fail_rootfs_part_size_exceeded: 1041 rootfs_op_size += block_size 1042 payload_gen.AddOperationWithData( 1043 common.ROOTFS, common.OpType.REPLACE, 1044 dst_extents=[(0, rootfs_op_size // block_size)], 1045 data_blob=os.urandom(rootfs_op_size)) 1046 1047 kernel_part_size = 0 1048 if kernel_part_size_provided: 1049 kernel_part_size = kernel_filesystem_size + block_size 1050 kernel_op_size = kernel_part_size or kernel_filesystem_size 1051 if fail_kernel_part_size_exceeded: 1052 kernel_op_size += block_size 1053 payload_gen.AddOperationWithData( 1054 common.KERNEL, common.OpType.REPLACE, 1055 dst_extents=[(0, kernel_op_size // block_size)], 1056 data_blob=os.urandom(kernel_op_size)) 1057 1058 # Generate payload (complete w/ signature) and create the test object. 1059 if fail_invalid_block_size: 1060 use_block_size = block_size + 5 # Not a power of two. 
1061 elif fail_mismatched_block_size: 1062 use_block_size = block_size * 2 # Different that payload stated. 1063 else: 1064 use_block_size = block_size 1065 1066 # For the unittests 237 is the value that generated for the payload. 1067 metadata_size = 237 1068 if fail_mismatched_metadata_size: 1069 metadata_size += 1 1070 1071 kwargs = { 1072 'payload_gen_dargs': { 1073 'privkey_file_name': test_utils._PRIVKEY_FILE_NAME, 1074 'padding': os.urandom(1024) if fail_excess_data else None}, 1075 'checker_init_dargs': { 1076 'assert_type': 'delta' if fail_wrong_payload_type else 'full', 1077 'block_size': use_block_size}} 1078 if fail_invalid_block_size: 1079 self.assertRaises(PayloadError, _GetPayloadChecker, 1080 payload_gen.WriteToFileWithData, **kwargs) 1081 else: 1082 payload_checker = _GetPayloadChecker(payload_gen.WriteToFileWithData, 1083 **kwargs) 1084 1085 kwargs2 = { 1086 'pubkey_file_name': test_utils._PUBKEY_FILE_NAME, 1087 'metadata_size': metadata_size, 1088 'part_sizes': { 1089 common.KERNEL: kernel_part_size, 1090 common.ROOTFS: rootfs_part_size}} 1091 1092 should_fail = (fail_wrong_payload_type or fail_mismatched_block_size or 1093 fail_mismatched_metadata_size or fail_excess_data or 1094 fail_rootfs_part_size_exceeded or 1095 fail_kernel_part_size_exceeded) 1096 if should_fail: 1097 self.assertRaises(PayloadError, payload_checker.Run, **kwargs2) 1098 else: 1099 self.assertIsNone(payload_checker.Run(**kwargs2)) 1100 1101 1102# This implements a generic API, hence the occasional unused args. 
# pylint: disable=W0613
def ValidateCheckOperationTest(op_type_name, allow_unhashed,
                               fail_src_extents, fail_dst_extents,
                               fail_mismatched_data_offset_length,
                               fail_missing_dst_extents, fail_src_length,
                               fail_dst_length, fail_data_hash,
                               fail_prev_data_offset, fail_bad_minor_version):
  """Returns True iff the combination of arguments represents a valid test."""
  op_type = _OpTypeByName(op_type_name)

  # REPLACE/REPLACE_BZ/REPLACE_XZ operations don't read data from src
  # partition. They are compatible with all valid minor versions, so we don't
  # need to check that.
  if (op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ,
                  common.OpType.REPLACE_XZ) and (fail_src_extents or
                                                 fail_src_length or
                                                 fail_bad_minor_version)):
    return False

  # SOURCE_COPY operation does not carry data.
  if (op_type == common.OpType.SOURCE_COPY and (
      fail_mismatched_data_offset_length or fail_data_hash or
      fail_prev_data_offset)):
    return False

  return True


def TestMethodBody(run_method_name, run_dargs):
  """Returns a function that invokes a named method with named arguments."""
  return lambda self: getattr(self, run_method_name)(**run_dargs)


def AddParametricTests(tested_method_name, arg_space, validate_func=None,
                       target_class=None):
  """Enumerates and adds specific parametric tests to PayloadCheckerTest.

  This function enumerates a space of test parameters (defined by arg_space),
  then binds a new, unique method name in PayloadCheckerTest to a test function
  that gets handed the said parameters. This is a preferable approach to doing
  the enumeration and invocation during the tests because this way each test is
  treated as a complete run by the unittest framework, and so benefits from the
  usual setUp/tearDown mechanics.

  Each generated test is named 'test<tested_method_name>' followed by
  '__<arg>=<value>' for every argument whose value is truthy or a non-bool
  integer (so 0 is listed, while False and None are omitted). Each test calls
  'Do<tested_method_name>Test' with the full argument combination.

  Args:
    tested_method_name: Name of the tested PayloadChecker method.
    arg_space: A dictionary containing variables (keys) and lists of values
               (values) associated with them.
    validate_func: A function used for validating test argument combinations.
    target_class: Class to add the tests to; defaults to PayloadCheckerTest.
  """
  if target_class is None:
    target_class = PayloadCheckerTest
  # These names do not depend on the argument combination.
  run_method_name = 'Do%sTest' % tested_method_name
  base_test_name = 'test%s' % tested_method_name
  for value_tuple in itertools.product(*arg_space.values()):
    run_dargs = dict(zip(arg_space, value_tuple))
    if validate_func and not validate_func(**run_dargs):
      continue
    test_method_name = base_test_name
    for arg_key, arg_val in run_dargs.items():
      # bool is a subclass of int, so exclude it explicitly; otherwise every
      # unset (False) flag would be spelled out in the test name.
      is_plain_int = isinstance(arg_val, int) and not isinstance(arg_val, bool)
      if arg_val or is_plain_int:
        test_method_name += '__%s=%s' % (arg_key, arg_val)
    setattr(target_class, test_method_name,
            TestMethodBody(run_method_name, run_dargs))


def AddAllParametricTests():
  """Enumerates and adds all parametric tests to PayloadCheckerTest."""
  # Add all _CheckElem() test cases.
  AddParametricTests('AddElem',
                     {'linebreak': (True, False),
                      'indent': (0, 1, 2),
                      'convert': (str, lambda s: s[::-1]),
                      'is_present': (True, False),
                      'is_mandatory': (True, False),
                      'is_submsg': (True, False)})

  # Add all _Add{Mandatory,Optional}Field tests.
  AddParametricTests('AddField',
                     {'is_mandatory': (True, False),
                      'linebreak': (True, False),
                      'indent': (0, 1, 2),
                      'convert': (str, lambda s: s[::-1]),
                      'is_present': (True, False)})

  # Add all _Add{Mandatory,Optional}SubMsg tests.
  AddParametricTests('AddSubMsg',
                     {'is_mandatory': (True, False),
                      'is_present': (True, False)})

  # Add all _CheckManifest() test cases.
  AddParametricTests('CheckManifest',
                     {'fail_mismatched_block_size': (True, False),
                      'fail_bad_sigs': (True, False),
                      'fail_mismatched_oki_ori': (True, False),
                      'fail_bad_oki': (True, False),
                      'fail_bad_ori': (True, False),
                      'fail_bad_nki': (True, False),
                      'fail_bad_nri': (True, False),
                      'fail_old_kernel_fs_size': (True, False),
                      'fail_old_rootfs_fs_size': (True, False),
                      'fail_new_kernel_fs_size': (True, False),
                      'fail_new_rootfs_fs_size': (True, False)})

  # Add all _CheckOperation() test cases.
  AddParametricTests('CheckOperation',
                     {'op_type_name': ('REPLACE', 'REPLACE_BZ', 'REPLACE_XZ',
                                       'SOURCE_COPY', 'SOURCE_BSDIFF',
                                       'PUFFDIFF', 'BROTLI_BSDIFF'),
                      'allow_unhashed': (True, False),
                      'fail_src_extents': (True, False),
                      'fail_dst_extents': (True, False),
                      'fail_mismatched_data_offset_length': (True, False),
                      'fail_missing_dst_extents': (True, False),
                      'fail_src_length': (True, False),
                      'fail_dst_length': (True, False),
                      'fail_data_hash': (True, False),
                      'fail_prev_data_offset': (True, False),
                      'fail_bad_minor_version': (True, False)},
                     validate_func=ValidateCheckOperationTest)

  # Add all _CheckOperations() test cases.
  AddParametricTests('CheckOperations',
                     {'fail_nonexhaustive_full_update': (True, False)})

  # Add all _CheckSignatures() test cases.
  AddParametricTests('CheckSignatures',
                     {'fail_empty_sigs_blob': (True, False),
                      'fail_sig_missing_fields': (True, False),
                      'fail_unknown_sig_version': (True, False),
                      'fail_incorrect_sig': (True, False)})

  # Add all _CheckManifestMinorVersion() test cases.
  AddParametricTests('CheckManifestMinorVersion',
                     {'minor_version': (None, 0, 2, 3, 4, 5, 555),
                      'payload_type': (checker._TYPE_FULL,
                                       checker._TYPE_DELTA)})

  # Add all Run() test cases.
  AddParametricTests('Run',
                     {'rootfs_part_size_provided': (True, False),
                      'kernel_part_size_provided': (True, False),
                      'fail_wrong_payload_type': (True, False),
                      'fail_invalid_block_size': (True, False),
                      'fail_mismatched_metadata_size': (True, False),
                      'fail_mismatched_block_size': (True, False),
                      'fail_excess_data': (True, False),
                      'fail_rootfs_part_size_exceeded': (True, False),
                      'fail_kernel_part_size_exceeded': (True, False)})


if __name__ == '__main__':
  AddAllParametricTests()
  unittest.main()