diff --git a/tools/util/bitint.py b/tools/util/bitint.py new file mode 100644 index 0000000..fde11aa --- /dev/null +++ b/tools/util/bitint.py @@ -0,0 +1,80 @@ +# Copyright (c) Kuba Szczodrzyński 2022-06-10. + +from typing import List, Tuple, Union + +from tools.util.intbin import uintmax +from tools.util.obj import SliceLike, slice2int + + +def bitcat(*vars: Tuple[Union["BitInt", int], SliceLike]) -> int: + """Concat all 'vars' denoted in a (value, slice) format into a bitstring.""" + out = 0 + for val, sl in vars: + if not isinstance(val, BitInt): + val = BitInt(val) + (start, stop) = slice2int(sl) + out <<= start - stop + 1 + out |= val[start:stop] + return out + + +def bitcatraw(*vars: Tuple[int, int]) -> int: + """Concat all 'vars' denoted in a (value, bitwidth) format into a bitstring.""" + out = 0 + for val, bits in vars: + out <<= bits + out |= val + return out + + +class BitInt(int): + """ + Wrapper for int supporting slice reading and assignment of + individual bits (counting from LSB to MSB, like '7:0'). + """ + + value: int = None + + def __init__(self, value: int) -> None: + self.value = value + + def __getitem__(self, key): + if self.value is None: + self.value = self + # for best performance, slice2int() type checking was disabled + if isinstance(key, int): + return (self.value >> key) % 2 + # (start, stop) = slice2int(key) + return (self.value >> key.stop) & uintmax(key.start - key.stop + 1) + + def __setitem__(self, key, value): + if self.value is None: + self.value = self + (start, stop) = slice2int(key) + + if value > uintmax(start - stop + 1): + raise ValueError("value is too big") + + tmp = self.value & ~uintmax(start + 1) + tmp |= self.value & uintmax(stop) + tmp |= value << stop + self.value = tmp + + def rep(self, n: int, sl: Union[SliceLike, List[SliceLike]]) -> int: + """Construct a bitstring from 'sl' (being a single slice or a list) + repeated 'n' times.""" + if isinstance(sl, list): + return self.cat(*(sl * n)) + return self.cat(*([sl] * n)) + + def cat(self, *slices: SliceLike) -> int: + """Construct a bitstring from this BitInt's parts denoted by 'slices'.""" + out = 0 + for sl in slices: + (start, stop) = slice2int(sl) + out <<= start - stop + 1 + out |= self[start:stop] + return out + + def __int__(self) -> int: + return self.value or self diff --git a/tools/util/bkcrypto.py b/tools/util/bkcrypto.py new file mode 100644 index 0000000..8b90658 --- /dev/null +++ b/tools/util/bkcrypto.py @@ -0,0 +1,167 @@ +# Copyright (c) Kuba Szczodrzyński 2022-06-10. + +from typing import List, Tuple + +from tools.util.bitint import BitInt, bitcatraw + + +def pn15(addr: int) -> int: + # wire [15:0] pn_tmp = {addr[6:0], addr[15:7]} ^ {16'h6371 & {4{addr[8:5]}}}; + a = ((addr % 0x80) * 0x200) + ((addr // 0x80) % 0x200) + b = (addr // 0x20) % 0x10 + c = 0x6371 & (b * 0x1111) + return a ^ c + + +def pn16(addr: int) -> int: + # wire [16:0] pn_tmp = {addr[9:0], addr[16:10]} ^ {17'h13659 & {addr[4],{4{addr[1],addr[5],addr[9],addr[13]}}}}; + a = ((addr % 0x400) * 0x80) + ((addr // 0x400) % 0x80) + b = (addr // 0x2000) % 2 + b += ((addr // 0x200) % 2) * 2 + b += ((addr // 0x20) % 2) * 4 + b += ((addr // 0x2) % 2) * 8 + c = (addr // 0x10) % 2 + d = 0x13659 & (c * 0x10000 + b * 0x1111) + return a ^ d + + +def pn32(addr: int) -> int: + # wire [31:0] pn_tmp = {addr[14:0], addr[31:15]} ^ {32'hE519A4F1 & {8{addr[5:2]}}}; + a = ((addr % 0x8000) * 0x20000) + ((addr // 0x8000) % 0x20000) + b = (addr // 0x4) % 0x10 + c = 0xE519A4F1 & (b * 0x11111111) + return a ^ c + + +class BekenCrypto: + # translated from https://github.com/ghsecuritylab/tysdk_for_bk7231/blob/master/toolchain/encrypt_crc/abc.c + coef0: BitInt + coef1_mix: int + coef1_hi16: int + bypass: bool = False + pn15_args: List[slice] = None + pn16_args: slice = None + pn32_args: Tuple[int, int] = None + random: int = 0 + + def __init__(self, coeffs: List[BitInt]) -> None: + (self.coef0, coef1, coef2, coef3) = coeffs + + # wire g_bypass = (coef3[31:24] == 8'hFF) | (coef3[31:24] == 8'h00); + self.bypass = coef3[31:24] in [0x00, 0xFF] + if self.bypass: + return + + # wire pn16_bit = coef3[4]; + # wire[16:0] pn16_addr = pn16_A ^ {coef1[15:8], pn16_bit, coef1[7:0]}; + self.coef1_mix = bitcatraw((coef1[15:8], 8), (coef3[4], 1), (coef1[7:0], 8)) + self.coef1_hi16 = coef1[31:16] + + # wire pn15_bps = g_bypass | coef3[0]; + pn15_bps = coef3[0] + # wire pn16_bps = g_bypass | coef3[1]; + pn16_bps = coef3[1] + # wire pn32_bps = g_bypass | coef3[2]; + pn32_bps = coef3[2] + # wire rand_bps = g_bypass | coef3[3]; + rand_bps = coef3[3] + + if coef3[3:0] == 0xF: + self.bypass = True + return + + if not pn15_bps: + # wire[1:0] pn15_sel = coef3[ 6: 5]; + pn15_sel = coef3[6:5] + # wire[15:0] pn15_A = (pn15_sel == 0) ? ({addr[31:24], addr[23:16]} ^ {addr[15:8], addr[ 7:0]}) : + # (pn15_sel == 1) ? ({addr[31:24], addr[23:16]} ^ {addr[ 7:0], addr[15:8]}) : + # (pn15_sel == 2) ? ({addr[23:16], addr[31:24]} ^ {addr[15:8], addr[ 7:0]}) : + # ({addr[23:16], addr[31:24]} ^ {addr[ 7:0], addr[15:8]}); + if pn15_sel == 0: + self.pn15_args = [ + slice(31, 24), + slice(23, 16), + slice(15, 8), + slice(7, 0), + ] + elif pn15_sel == 1: + self.pn15_args = [ + slice(31, 24), + slice(23, 16), + slice(7, 0), + slice(15, 8), + ] + elif pn15_sel == 2: + self.pn15_args = [ + slice(23, 16), + slice(31, 24), + slice(15, 8), + slice(7, 0), + ] + else: + self.pn15_args = [ + slice(23, 16), + slice(31, 24), + slice(7, 0), + slice(15, 8), + ] + + if not pn16_bps: + # wire[1:0] pn16_sel = coef3[ 9: 8]; + pn16_sel = coef3[9:8] + # wire[16:0] pn16_A = (pn16_sel == 0) ? addr[16:0] : + # (pn16_sel == 1) ? addr[17:1] : + # (pn16_sel == 2) ? addr[18:2] : + # addr[19:3]; + self.pn16_args = slice(16 + pn16_sel, pn16_sel) + + if not pn32_bps: + # wire[1:0] pn32_sel = coef3[12:11]; + pn32_sel = coef3[12:11] + # wire[31:0] pn32_A = (pn32_sel == 0) ? addr[31:0] : + # (pn32_sel == 1) ? {addr[ 7:0], addr[31: 8]} : + # (pn32_sel == 2) ? {addr[15:0], addr[31:16]} : + # {addr[23:0], addr[31:24]}; + PN32_SHIFTS = ( + (0, 0), + (2**8, 2**24), + (2**16, 2**16), + (2**24, 2**8), + ) + self.pn32_args = PN32_SHIFTS[pn32_sel] + + # wire[31:0] random = rand_bps ? 32'h00000000 : coef2[31:0]; + self.random = 0 if rand_bps else coef2 + + def encrypt_u32(self, addr: int, data: int) -> int: + if self.bypass: + return data + addr = BitInt(addr) + + pn15_v = 0 + pn16_v = 0 + pn32_v = 0 + + if self.pn15_args: + pn15_a = (addr[self.pn15_args[0]] * 0x100) + addr[self.pn15_args[1]] + pn15_b = (addr[self.pn15_args[2]] * 0x100) + addr[self.pn15_args[3]] + pn15_A = pn15_a ^ pn15_b + # wire[15:0] pn15_addr = pn15_A ^ coef1[31:16]; + pn15_addr = pn15_A ^ self.coef1_hi16 + pn15_v = pn15(pn15_addr) + + if self.pn16_args: + pn16_A = addr[self.pn16_args] + # wire[16:0] pn16_addr = pn16_A ^ {coef1[15:8], pn16_bit, coef1[7:0]}; + pn16_addr = pn16_A ^ self.coef1_mix + pn16_v = pn16(pn16_addr) + + if self.pn32_args: + pn32_A = (addr // self.pn32_args[0]) + (addr * self.pn32_args[1]) + # wire[31:0] pn32_addr = pn32_A ^ coef0[31:0]; + pn32_addr = pn32_A ^ self.coef0 + pn32_v = pn32(pn32_addr) + + # assign pnout = pn32[31:0] ^ {pn15[15:0], pn16[15:0]} ^ random[31:0]; + pnout = pn32_v ^ ((pn15_v * 0x10000) + (pn16_v % 0x10000)) ^ self.random + return data ^ pnout diff --git a/tools/util/bkutil.py b/tools/util/bkutil.py new file mode 100644 index 0000000..51d4ca3 --- /dev/null +++ b/tools/util/bkutil.py @@ -0,0 +1,314 @@ +# Copyright (c) Kuba Szczodrzyński 2022-06-10. + +import sys +from os.path import dirname, join + +sys.path.append(join(dirname(__file__), "..", "..")) + +from argparse import ArgumentParser, FileType +from binascii import crc32 +from dataclasses import dataclass, field +from enum import IntFlag +from io import SEEK_SET, FileIO +from os import stat +from struct import Struct +from time import time +from typing import Union + +from tools.util.bitint import BitInt +from tools.util.bkcrypto import BekenCrypto +from tools.util.crc16 import CRC16 +from tools.util.fileio import readbin, writebin +from tools.util.intbin import ( + ByteGenerator, + align_up, + betoint, + biniter, + fileiter, + geniter, + inttobe16, + inttole32, + letoint, + pad_data, + pad_up, +) + + +class OTAAlgorithm(IntFlag): + NONE = 0 + CRYPT_XOR = 1 + CRYPT_AES256 = 2 + COMPRESS_GZIP = 256 + COMPRESS_QUICKLZ = 512 + COMPRESS_FASTLZ = 768 + + +@dataclass +class RBL: + ota_algo: OTAAlgorithm = OTAAlgorithm.NONE + timestamp: float = field(default_factory=time) + name: Union[str, bytes] = "app" + version: Union[str, bytes] = "1.00" + sn: Union[str, bytes] = "0" * 23 + data_crc: int = 0 + data_hash: int = 0x811C9DC5 # https://github.com/znerol/py-fnvhash/blob/master/fnvhash/__init__.py + raw_size: int = 0 + data_size: int = 0 + container_size: int = 0 + has_part_table: bool = False + + def update(self, data: bytes): + self.data_crc = crc32(data, self.data_crc) + for byte in data: + if self.data_size < self.raw_size: + self.data_hash ^= byte + self.data_hash *= 0x01000193 + self.data_hash %= 0x100000000 + self.data_size += 1 + + def serialize(self) -> bytes: + if isinstance(self.name, str): + self.name = self.name.encode() + if isinstance(self.version, str): + self.version = self.version.encode() + if isinstance(self.sn, str): + self.sn = self.sn.encode() + # based on https://github.com/khalednassar/bk7231tools/blob/main/bk7231tools/analysis/rbl.py + struct = Struct("<4sII16s24s24sIIII") # without header CRC + rbl = struct.pack( + b"RBL\x00", + self.ota_algo, + int(self.timestamp), + pad_data(self.name, 16, 0x00), + pad_data(self.version, 24, 0x00), + pad_data(self.sn, 24, 0x00), + self.data_crc, + self.data_hash, + self.raw_size, + self.data_size, + ) + return rbl + inttole32(crc32(rbl)) + + @classmethod + def deserialize(cls, data: bytes) -> "RBL": + crc_found = letoint(data[-4:]) + data = data[:-4] + crc_expected = crc32(data) + if crc_expected != crc_found: + raise ValueError( + f"Invalid RBL CRC (expected {crc_expected:X}, found {crc_found:X})" + ) + struct = Struct(" None: + if coeffs: + if isinstance(coeffs, str): + coeffs = bytes.fromhex(coeffs) + if len(coeffs) != 16: + raise ValueError( + f"Invalid length of encryption coefficients: {len(coeffs)}" + ) + coeffs = list(map(BitInt, map(betoint, biniter(coeffs, 4)))) + self.crypto = BekenCrypto(coeffs) + + def crc(self, data: ByteGenerator) -> ByteGenerator: + for block in geniter(data, 32): + crc = CRC16.CMS.calc(block) + yield block + yield inttobe16(crc) + + def uncrc(self, data: ByteGenerator, check: bool = True) -> ByteGenerator: + for block in geniter(data, 34): + if check: + crc = CRC16.CMS.calc(block[0:32]) + crc_found = betoint(block[32:34]) + if crc != crc_found: + print(f"CRC invalid: expected={crc:X}, found={crc_found:X}") + return + yield block[0:32] + + def crypt(self, addr: int, data: ByteGenerator) -> ByteGenerator: + for word in geniter(data, 4): + word = letoint(word) + word = self.crypto.encrypt_u32(addr, word) + word = inttole32(word) + yield word + addr += 4 + + def package(self, f: FileIO, addr: int, size: int, rbl: RBL) -> ByteGenerator: + if not rbl.container_size: + raise ValueError("RBL must have a total size when packaging") + crc_total = 0 + + # when to stop reading input data + data_end = size + if rbl.has_part_table: + data_end = size - 0xC0 # do not encrypt the partition table + + # set RBL size including one 16-byte padding + rbl.raw_size = align_up(size + 16, 32) + 16 + + # encrypt the input file, padded to 32 bytes + data_crypt_gen = self.crypt( + addr, fileiter(f, size=32, padding=0xFF, count=data_end) + ) + # iterate over encrypted 32-byte blocks + for block in geniter(data_crypt_gen, 32): + # add CRC16 and yield + yield from self.crc(block) + crc_total += 2 + rbl.update(block) + + # temporary buffer for small-size operations + buf = b"\xff" * 16 # add 16 bytes of padding + + if rbl.has_part_table: + # add an unencrypted partition table + buf += f.read(0xC0) + + # update RBL + rbl.update(buf) + # add last padding with different values + rbl.update(b"\x10" * 16) + + # add last padding with normal values + buf += b"\xff" * 16 + # yield the temporary buffer + yield from self.crc(buf) + crc_total += 2 * (len(buf) // 32) + + # pad the entire container with 0xFF, excluding RBL and its CRC16 + pad_size = pad_up(rbl.data_size + crc_total, rbl.container_size) - 102 + for _ in range(pad_size): + yield b"\xff" + + # yield RBL with CRC16 + yield from self.crc(rbl.serialize()) + + +def auto_int(x): + return int(x, 0) + + +def add_common_args(parser): + parser.add_argument( + "coeffs", type=str, help="Encryption coefficients (hex string, 32 chars)" + ) + parser.add_argument("input", type=FileType("rb"), help="Input file") + parser.add_argument("output", type=FileType("wb"), help="Output file") + parser.add_argument("addr", type=auto_int, help="Memory address (dec/hex)") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Encrypt/decrypt Beken firmware binaries") + sub = parser.add_subparsers(dest="action", required=True) + + encrypt = sub.add_parser("encrypt", help="Encrypt binary files without packaging") + add_common_args(encrypt) + encrypt.add_argument("-c", "--crc", help="Include CRC16", action="store_true") + + decrypt = sub.add_parser("decrypt", description="Decrypt unpackaged binary files") + add_common_args(decrypt) + decrypt.add_argument( + "-C", + "--no-crc-check", + help="Do not check CRC16 (if present)", + action="store_true", + ) + + package = sub.add_parser( + "package", description="Package raw binary files as RBL containers" + ) + add_common_args(package) + package.add_argument("size", type=auto_int, help="RBL total size (dec/hex)") + package.add_argument( + "-n", + "--name", + type=str, + help="Firmware name (default: app)", + default="app", + required=False, + ) + package.add_argument( + "-v", + "--version", + type=str, + help="Firmware version (default: 1.00)", + default="1.00", + required=False, + ) + + unpackage = sub.add_parser( + "unpackage", description="Unpackage a single RBL container" + ) + add_common_args(unpackage) + unpackage.add_argument( + "offset", type=auto_int, help="Offset in input file (dec/hex)" + ) + unpackage.add_argument("size", type=auto_int, help="RBL total size (dec/hex)") + + args = parser.parse_args() + bk = BekenBinary(args.coeffs) + f: FileIO = args.input + size = stat(args.input.name).st_size + start = time() + + if args.action == "encrypt": + print(f"Encrypting '{f.name}' ({size} bytes)") + if args.crc: + print(f" - calculating 32-byte block CRC16...") + gen = bk.crc(bk.crypt(args.addr, f)) + else: + print(f" - as raw binary, without CRC16...") + gen = bk.crypt(args.addr, f) + + if args.action == "decrypt": + print(f"Decrypting '{f.name}' ({size} bytes)") + if size % 34 == 0: + if args.no_crc_check: + print(f" - has CRC16, skipping checks...") + else: + print(f" - has CRC16, checking...") + gen = bk.crypt(args.addr, bk.uncrc(f, check=not args.no_crc_check)) + elif size % 4 != 0: + raise ValueError("Input file has invalid length") + else: + print(f" - raw binary, no CRC") + gen = bk.crypt(args.addr, f) + + if args.action == "package": + print(f"Packaging {args.name} '{f.name}' for memory address 0x{args.addr:X}") + rbl = RBL(name=args.name, version=args.version) + if args.name == "bootloader": + rbl.has_part_table = True + print(f" - in bootloader mode; partition table unencrypted") + rbl.container_size = args.size + print(f" - container size: 0x{args.size:X}") + gen = bk.package(f, args.addr, size, rbl) + + if args.action == "unpackage": + print(f"Unpackaging '{f.name}' (at 0x{args.offset:X}, size 0x{args.size:X})") + f.seek(args.offset + args.size - 102, SEEK_SET) + rbl = f.read(102) + rbl = b"".join(bk.uncrc(rbl)) + rbl = RBL.deserialize(rbl) + print(f" - found '{rbl.name}' ({rbl.version}), size {rbl.data_size}") + f.seek(0, SEEK_SET) + crc_size = (rbl.data_size - 16) // 32 * 34 + gen = bk.crypt(args.addr, bk.uncrc(fileiter(f, 32, 0xFF, crc_size))) + + written = 0 + for data in gen: + args.output.write(data) + written += len(data) + print(f" - wrote {written} bytes in {time()-start:.3f} s") diff --git a/tools/util/intbin.py b/tools/util/intbin.py index 9b0b2f3..209e865 100644 --- a/tools/util/intbin.py +++ b/tools/util/intbin.py @@ -1,5 +1,10 @@ # Copyright (c) Kuba Szczodrzyński 2022-06-02. +from io import FileIO +from typing import IO, Generator, Union + +ByteGenerator = Generator[bytes, None, None] + def bswap(data: bytes) -> bytes: """Reverse the byte array (big-endian <-> little-endian).""" @@ -137,3 +142,63 @@ def uint32(val): def uintmax(bits: int) -> int: """Get maximum integer size for given bit width.""" return (2**bits) - 1 + + +def biniter(data: bytes, size: int) -> ByteGenerator: + """Iterate over 'data' in 'size'-bytes long chunks, returning + a generator.""" + if len(data) % size != 0: + raise ValueError( + f"Data length must be a multiple of block size ({len(data)} % {size})" + ) + for i in range(0, len(data), size): + yield data[i : i + size] + + +def geniter(gen: Union[ByteGenerator, bytes, IO], size: int) -> ByteGenerator: + """ + Take data from 'gen' and generate 'size'-bytes long chunks. + + If 'gen' is a bytes or IO object, it is wrapped using + biniter() or fileiter(). + """ + if isinstance(gen, bytes): + yield from biniter(gen, size) + return + if isinstance(gen, IO): + yield from fileiter(gen, size) + return + buf = b"" + for part in gen: + if not buf and len(part) == size: + yield part + continue + buf += part + while len(buf) >= size: + yield buf[0:size] + buf = buf[size:] + + +def fileiter( + f: FileIO, size: int, padding: int = 0x00, count: int = 0 +) -> ByteGenerator: + """ + Read data from 'f' and generate 'size'-bytes long chunks. + + Pad incomplete chunks with 'padding' character. + + Read up to 'count' bytes from 'f', if specified. Data is padded + if not on chunk boundary. + """ + read = 0 + while True: + if count and read + size >= count: + yield pad_data(f.read(count % size), size, padding) + return + data = f.read(size) + read += len(data) + if len(data) < size: + # got only part of the block + yield pad_data(data, size, padding) + return + yield data diff --git a/tools/util/obj.py b/tools/util/obj.py index ca35173..9267fa2 100644 --- a/tools/util/obj.py +++ b/tools/util/obj.py @@ -1,7 +1,9 @@ # Copyright (c) Kuba Szczodrzyński 2022-06-02. import json -from typing import Union +from typing import Tuple, Union + +SliceLike = Union[slice, str, int] def merge_dicts(d1, d2, path=None): @@ -27,3 +29,24 @@ def get(data: dict, path: str): return data.get(path, None) key, _, path = path.partition(".") return get(data.get(key, None), path) + + +def slice2int(val: SliceLike) -> Tuple[int, int]: + """Convert a slice-like value (slice, string '7:0' or '3', int '3') + to a tuple of (start, stop).""" + if isinstance(val, int): + return (val, val) + if isinstance(val, slice): + if val.step: + raise ValueError("value must be a slice without step") + if val.start < val.stop: + raise ValueError("start must not be less than stop") + return (val.start, val.stop) + if isinstance(val, str): + if ":" in val: + val = val.split(":") + if len(val) == 2: + return tuple(map(int, val)) + elif val.isnumeric(): + return (int(val), int(val)) + raise ValueError(f"invalid slice format: {val}")