Projects STRLCPY criu Commits a0cc95c0
🤬
  • lib/py: reduce code duplication

    Refactor lib/py/images/images.py to reduce code duplication
    by extracting repetitive code into helper functions and
    private methods. This improves code readability and maintainability,
    and reduces the risk of bugs caused by duplicated code.
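    Additionally, lib/py/images/images.py is added to the list of files
    checked by flake8 in the Makefile during CI, and the flake8 findings
    in it are fixed (bare `except:` clauses become `except Exception:`,
    and blank-line/whitespace issues are corrected).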
    
    Fixes: #340
    
    Signed-off-by: Kouame Behouba Manasse <[email protected]>
  • Loading...
  • a0cc95c0
    1 parent 85b5c1e4
  • ■ ■ ■ ■ ■
    Makefile
    skipped 416 lines
    417 417   flake8 --config=scripts/flake8.cfg test/inhfd/*.py
    418 418   flake8 --config=scripts/flake8.cfg test/others/rpc/config_file.py
    419 419   flake8 --config=scripts/flake8.cfg lib/py/images/pb2dict.py
     420 + flake8 --config=scripts/flake8.cfg lib/py/images/images.py
    420 421   flake8 --config=scripts/flake8.cfg scripts/criu-ns
    421 422   flake8 --config=scripts/flake8.cfg scripts/crit-setup.py
    422 423   flake8 --config=scripts/flake8.cfg coredump/
    skipped 48 lines
  • ■ ■ ■ ■ ■ ■
    lib/py/images/images.py
    skipped 68 lines
    69 69   self.magic = magic
    70 70   
    71 71   
     72 +def decode_base64_data(data):
     73 + """A helper function to decode base64 data."""
     74 + if (sys.version_info > (3, 0)):
     75 + return base64.decodebytes(str.encode(data))
     76 + else:
     77 + return base64.decodebytes(data)
     78 + 
     79 + 
     80 +def write_base64_data(f, data):
     81 + """A helper function to write base64 encoded data to a file."""
     82 + if (sys.version_info > (3, 0)):
     83 + f.write(base64.decodebytes(str.encode(data)))
     84 + else:
     85 + f.write(base64.decodebytes(data))
     86 + 
     87 + 
    72 88  # Generic class to handle loading/dumping criu images entries from/to bin
    73 89  # format to/from dict(json).
    74 90  class entry_handler:
    skipped 210 lines
    285 301   size = len(pb_str)
    286 302   f.write(struct.pack('i', size))
    287 303   f.write(pb_str)
    288  - if (sys.version_info > (3, 0)):
    289  - f.write(base64.decodebytes(str.encode(item['extra'])))
    290  - else:
    291  - f.write(base64.decodebytes(item['extra']))
     304 + write_base64_data(f, item['extra'])
    292 305   else:
    293  - if (sys.version_info > (3, 0)):
    294  - f.write(base64.decodebytes(str.encode(item['extra'])))
    295  - else:
    296  - f.write(base64.decodebytes(item['extra']))
     306 + write_base64_data(f, item['extra'])
    297 307   
    298 308   def dumps(self, entries):
    299 309   f = io.BytesIO('')
    skipped 14 lines
    314 324   return base64.encodebytes(data).decode('utf-8')
    315 325   
    316 326   def dump(self, extra, f, pload):
    317  - if (sys.version_info > (3, 0)):
    318  - data = base64.decodebytes(str.encode(extra))
    319  - else:
    320  - data = base64.decodebytes(extra)
     327 + data = decode_base64_data(extra)
    321 328   f.write(data)
    322 329   
    323 330   def skip(self, f, pload):
    skipped 8 lines
    332 339   return base64.encodebytes(data).decode('utf-8')
    333 340   
    334 341   def dump(self, extra, f, _unused):
    335  - if (sys.version_info > (3, 0)):
    336  - data = base64.decodebytes(str.encode(extra))
    337  - else:
    338  - data = base64.decodebytes(extra)
     342 + data = decode_base64_data(extra)
    339 343   f.write(data)
    340 344   
    341 345   def skip(self, f, pload):
    skipped 14 lines
    356 360   return d
    357 361   
    358 362   def dump(self, extra, f, _unused):
    359  - if (sys.version_info > (3, 0)):
    360  - inq = base64.decodebytes(str.encode(extra['inq']))
    361  - outq = base64.decodebytes(str.encode(extra['outq']))
    362  - else:
    363  - inq = base64.decodebytes(extra['inq'])
    364  - outq = base64.decodebytes(extra['outq'])
     363 + inq = decode_base64_data(extra['inq'])
     364 + outq = decode_base64_data(extra['outq'])
    365 365   
    366 366   f.write(inq)
    367 367   f.write(outq)
    skipped 1 lines
    369 369   def skip(self, f, pbuff):
    370 370   f.seek(0, os.SEEK_END)
    371 371   return pbuff.inq_len + pbuff.outq_len
     372 + 
    372 373   
    373 374  class bpfmap_data_extra_handler:
    374 375   def load(self, f, pload):
    skipped 8 lines
    383 384   def skip(self, f, pload):
    384 385   f.seek(pload.bytes, os.SEEK_CUR)
    385 386   return pload.bytes
     387 + 
    386 388   
    387 389  class ipc_sem_set_handler:
    388 390   def load(self, f, pbuff):
    389 391   entry = pb2dict.pb2dict(pbuff)
    390 392   size = sizeof_u16 * entry['nsems']
    391 393   rounded = round_up(size, sizeof_u64)
    392  - s = array.array('H')
    393  - if s.itemsize != sizeof_u16:
    394  - raise Exception("Array size mismatch")
     394 + s = self._get_sem_array()
    395 395   s.frombytes(f.read(size))
    396 396   f.seek(rounded - size, 1)
    397 397   return s.tolist()
    skipped 2 lines
    400 400   entry = pb2dict.pb2dict(pbuff)
    401 401   size = sizeof_u16 * entry['nsems']
    402 402   rounded = round_up(size, sizeof_u64)
    403  - s = array.array('H')
    404  - if s.itemsize != sizeof_u16:
    405  - raise Exception("Array size mismatch")
     403 + s = self._get_sem_array()
    406 404   s.fromlist(extra)
    407 405   if len(s) != entry['nsems']:
    408 406   raise Exception("Number of semaphores mismatch")
    skipped 6 lines
    415 413   f.seek(round_up(size, sizeof_u64), os.SEEK_CUR)
    416 414   return size
    417 415   
     416 + def _get_sem_array(self):
     417 + s = array.array('H')
     418 + if s.itemsize != sizeof_u16:
     419 + raise Exception("Array size mismatch")
     420 + return s
     421 + 
    418 422   
    419 423  class ipc_msg_queue_handler:
    420 424   def load(self, f, pbuff):
    421  - entry = pb2dict.pb2dict(pbuff)
    422  - messages = []
    423  - for x in range(0, entry['qnum']):
    424  - buf = f.read(4)
    425  - if len(buf) == 0:
    426  - break
    427  - size, = struct.unpack('i', buf)
    428  - msg = pb.ipc_msg()
    429  - msg.ParseFromString(f.read(size))
    430  - rounded = round_up(msg.msize, sizeof_u64)
    431  - data = f.read(msg.msize)
    432  - f.seek(rounded - msg.msize, 1)
    433  - messages.append(pb2dict.pb2dict(msg))
    434  - messages.append(base64.encodebytes(data).decode('utf-8'))
     425 + messages, _ = self._read_messages(f, pbuff)
    435 426   return messages
    436 427   
    437 428   def dump(self, extra, f, pbuff):
    skipped 5 lines
    443 434   f.write(struct.pack('i', size))
    444 435   f.write(msg_str)
    445 436   rounded = round_up(msg.msize, sizeof_u64)
    446  - if (sys.version_info > (3, 0)):
    447  - data = base64.decodebytes(str.encode(extra[i + 1]))
    448  - else:
    449  - data = base64.decodebytes(extra[i + 1])
     437 + data = decode_base64_data(extra[i + 1])
    450 438   f.write(data[:msg.msize])
    451 439   f.write(b'\0' * (rounded - msg.msize))
    452 440   
    def skip(self, f, pbuff):
        """Move *f* past every queued message and return the payload length.

        Runs ``_read_messages`` with ``skip_data=True``, so message bodies are
        seeked over instead of being read and base64-encoded. Only the
        accumulated length is returned; the (empty) message list is dropped.
        """
        (_unused_messages, skipped_len) = self._read_messages(
            f, pbuff, skip_data=True)
        return skipped_len
     444 + 
     445 + def _read_messages(self, f, pbuff, skip_data=False):
    454 446   entry = pb2dict.pb2dict(pbuff)
     447 + messages = []
    455 448   pl_len = 0
    456 449   for x in range(0, entry['qnum']):
    457 450   buf = f.read(4)
    skipped 3 lines
    461 454   msg = pb.ipc_msg()
    462 455   msg.ParseFromString(f.read(size))
    463 456   rounded = round_up(msg.msize, sizeof_u64)
    464  - f.seek(rounded, os.SEEK_CUR)
    465 457   pl_len += size + msg.msize
    466 458   
    467  - return pl_len
     459 + if skip_data:
     460 + f.seek(rounded, os.SEEK_CUR)
     461 + else:
     462 + data = f.read(msg.msize)
     463 + f.seek(rounded - msg.msize, 1)
     464 + messages.append(pb2dict.pb2dict(msg))
     465 + messages.append(base64.encodebytes(data).decode('utf-8'))
     466 + 
     467 + return messages, pl_len
    468 468   
    469 469   
    470 470  class ipc_shm_handler:
    skipped 89 lines
    560 560   'MEMFD_INODE': entry_handler(pb.memfd_inode_entry),
    561 561   'BPFMAP_FILE': entry_handler(pb.bpfmap_file_entry),
    562 562   'BPFMAP_DATA': entry_handler(pb.bpfmap_data_entry,
    563  - bpfmap_data_extra_handler()),
     563 + bpfmap_data_extra_handler()),
    564 564   'APPARMOR': entry_handler(pb.apparmor_entry),
    565 565  }
    566 566   
    skipped 7 lines
    574 574   
    575 575   try:
    576 576   m = magic.by_val[img_magic]
    577  - except:
     577 + except Exception:
    578 578   raise MagicException(img_magic)
    579 579   
    580 580   try:
    581 581   handler = handlers[m]
    582  - except:
     582 + except Exception:
    583 583   raise Exception("No handler found for image with magic " + m)
    584 584   
    585 585   return m, handler
    skipped 55 lines
    641 641   
    642 642   try:
    643 643   handler = handlers[m]
    644  - except:
     644 + except Exception:
    645 645   raise Exception("No handler found for image with such magic")
    646 646   
    647 647   handler.dump(img['entries'], f)
    skipped 11 lines
Please wait...
Page is in error, reload to recover