[tools] Add Beken binary crypto utility

This commit is contained in:
Kuba Szczodrzyński
2022-06-11 23:00:00 +02:00
parent c3f2ce57f0
commit dba602a081
5 changed files with 650 additions and 1 deletions

80
tools/util/bitint.py Normal file
View File

@@ -0,0 +1,80 @@
# Copyright (c) Kuba Szczodrzyński 2022-06-10.
from typing import List, Tuple, Union
from tools.util.intbin import uintmax
from tools.util.obj import SliceLike, slice2int
def bitcat(*vars: Tuple[Union["BitInt", int], SliceLike]) -> int:
"""Concat all 'vars' denoted in a (value, slice) format into a bitstring."""
out = 0
for val, sl in vars:
if not isinstance(val, BitInt):
val = BitInt(val)
(start, stop) = slice2int(sl)
out <<= start - stop + 1
out |= val[start:stop]
return out
def bitcatraw(*vars: Tuple[int, int]) -> int:
"""Concat all 'vars' denoted in a (value, bitwidth) format into a bitstring."""
out = 0
for val, bits in vars:
out <<= bits
out |= val
return out
class BitInt(int):
"""
Wrapper for int supporting slice reading and assignment of
individual bits (counting from LSB to MSB, like '7:0').
"""
value: int = None
def __init__(self, value: int) -> None:
self.value = value
def __getitem__(self, key):
if self.value is None:
self.value = self
# for best performance, slice2int() type checking was disabled
if isinstance(key, int):
return (self.value >> key) % 2
# (start, stop) = slice2int(key)
return (self.value >> key.stop) & uintmax(key.start - key.stop + 1)
def __setitem__(self, key, value):
if self.value is None:
self.value = self
(start, stop) = slice2int(key)
if value > uintmax(start - stop + 1):
raise ValueError("value is too big")
tmp = self.value & ~uintmax(start + 1)
tmp |= self.value & uintmax(stop)
tmp |= value << stop
self.value = tmp
def rep(self, n: int, sl: Union[SliceLike, List[SliceLike]]) -> int:
"""Construct a bitstring from 'sl' (being a single slice or a list)
repeated 'n' times."""
if isinstance(sl, list):
return self.cat(*(sl * n))
return self.cat(*([sl] * n))
def cat(self, *slices: SliceLike) -> int:
"""Construct a bitstring from this BitInt's parts denoted by 'slices'."""
out = 0
for sl in slices:
(start, stop) = slice2int(sl)
out <<= start - stop + 1
out |= self[start:stop]
return out
def __int__(self) -> int:
return self.value or self

167
tools/util/bkcrypto.py Normal file
View File

@@ -0,0 +1,167 @@
# Copyright (c) Kuba Szczodrzyński 2022-06-10.
from typing import List, Tuple
from tools.util.bitint import BitInt, bitcatraw
def pn15(addr: int) -> int:
# wire [15:0] pn_tmp = {addr[6:0], addr[15:7]} ^ {16'h6371 & {4{addr[8:5]}}};
a = ((addr % 0x80) * 0x200) + ((addr // 0x80) % 0x200)
b = (addr // 0x20) % 0x10
c = 0x6371 & (b * 0x1111)
return a ^ c
def pn16(addr: int) -> int:
# wire [16:0] pn_tmp = {addr[9:0], addr[16:10]} ^ {17'h13659 & {addr[4],{4{addr[1],addr[5],addr[9],addr[13]}}}};
a = ((addr % 0x400) * 0x80) + ((addr // 0x400) % 0x80)
b = (addr // 0x2000) % 2
b += ((addr // 0x200) % 2) * 2
b += ((addr // 0x20) % 2) * 4
b += ((addr // 0x2) % 2) * 8
c = (addr // 0x10) % 2
d = 0x13659 & (c * 0x10000 + b * 0x1111)
return a ^ d
def pn32(addr: int) -> int:
# wire [31:0] pn_tmp = {addr[14:0], addr[31:15]} ^ {32'hE519A4F1 & {8{addr[5:2]}}};
a = ((addr % 0x8000) * 0x20000) + ((addr // 0x8000) % 0x20000)
b = (addr // 0x4) % 0x10
c = 0xE519A4F1 & (b * 0x11111111)
return a ^ c
class BekenCrypto:
# translated from https://github.com/ghsecuritylab/tysdk_for_bk7231/blob/master/toolchain/encrypt_crc/abc.c
coef0: BitInt
coef1_mix: int
coef1_hi16: int
bypass: bool = False
pn15_args: List[slice] = None
pn16_args: slice = None
pn32_args: Tuple[int, int] = None
random: int = 0
def __init__(self, coeffs: List[BitInt]) -> None:
(self.coef0, coef1, coef2, coef3) = coeffs
# wire g_bypass = (coef3[31:24] == 8'hFF) | (coef3[31:24] == 8'h00);
self.bypass = coef3[31:24] in [0x00, 0xFF]
if self.bypass:
return
# wire pn16_bit = coef3[4];
# wire[16:0] pn16_addr = pn16_A ^ {coef1[15:8], pn16_bit, coef1[7:0]};
self.coef1_mix = bitcatraw((coef1[15:8], 8), (coef3[4], 1), (coef1[7:0], 8))
self.coef1_hi16 = coef1[31:16]
# wire pn15_bps = g_bypass | coef3[0];
pn15_bps = coef3[0]
# wire pn16_bps = g_bypass | coef3[1];
pn16_bps = coef3[1]
# wire pn32_bps = g_bypass | coef3[2];
pn32_bps = coef3[2]
# wire rand_bps = g_bypass | coef3[3];
rand_bps = coef3[3]
if coef3[3:0] == 0xF:
self.bypass = True
return
if not pn15_bps:
# wire[1:0] pn15_sel = coef3[ 6: 5];
pn15_sel = coef3[6:5]
# wire[15:0] pn15_A = (pn15_sel == 0) ? ({addr[31:24], addr[23:16]} ^ {addr[15:8], addr[ 7:0]}) :
# (pn15_sel == 1) ? ({addr[31:24], addr[23:16]} ^ {addr[ 7:0], addr[15:8]}) :
# (pn15_sel == 2) ? ({addr[23:16], addr[31:24]} ^ {addr[15:8], addr[ 7:0]}) :
# ({addr[23:16], addr[31:24]} ^ {addr[ 7:0], addr[15:8]});
if pn15_sel == 0:
self.pn15_args = [
slice(31, 24),
slice(23, 16),
slice(15, 8),
slice(7, 0),
]
elif pn15_sel == 1:
self.pn15_args = [
slice(31, 24),
slice(23, 16),
slice(7, 0),
slice(15, 8),
]
elif pn15_sel == 2:
self.pn15_args = [
slice(23, 16),
slice(31, 24),
slice(15, 8),
slice(7, 0),
]
else:
self.pn15_args = [
slice(23, 16),
slice(31, 24),
slice(7, 0),
slice(15, 8),
]
if not pn16_bps:
# wire[1:0] pn16_sel = coef3[ 9: 8];
pn16_sel = coef3[9:8]
# wire[16:0] pn16_A = (pn16_sel == 0) ? addr[16:0] :
# (pn16_sel == 1) ? addr[17:1] :
# (pn16_sel == 2) ? addr[18:2] :
# addr[19:3];
self.pn16_args = slice(16 + pn16_sel, pn16_sel)
if not pn32_bps:
# wire[1:0] pn32_sel = coef3[12:11];
pn32_sel = coef3[12:11]
# wire[31:0] pn32_A = (pn32_sel == 0) ? addr[31:0] :
# (pn32_sel == 1) ? {addr[ 7:0], addr[31: 8]} :
# (pn32_sel == 2) ? {addr[15:0], addr[31:16]} :
# {addr[23:0], addr[31:24]};
PN32_SHIFTS = (
(0, 0),
(2**8, 2**24),
(2**16, 2**16),
(2**24, 2**8),
)
self.pn32_args = PN32_SHIFTS[pn32_sel]
# wire[31:0] random = rand_bps ? 32'h00000000 : coef2[31:0];
self.random = 0 if rand_bps else coef2
def encrypt_u32(self, addr: int, data: int) -> int:
if self.bypass:
return data
addr = BitInt(addr)
pn15_v = 0
pn16_v = 0
pn32_v = 0
if self.pn15_args:
pn15_a = (addr[self.pn15_args[0]] * 0x100) + addr[self.pn15_args[1]]
pn15_b = (addr[self.pn15_args[2]] * 0x100) + addr[self.pn15_args[3]]
pn15_A = pn15_a ^ pn15_b
# wire[15:0] pn15_addr = pn15_A ^ coef1[31:16];
pn15_addr = pn15_A ^ self.coef1_hi16
pn15_v = pn15(pn15_addr)
if self.pn16_args:
pn16_A = addr[self.pn16_args]
# wire[16:0] pn16_addr = pn16_A ^ {coef1[15:8], pn16_bit, coef1[7:0]};
pn16_addr = pn16_A ^ self.coef1_mix
pn16_v = pn16(pn16_addr)
if self.pn32_args:
pn32_A = (addr // self.pn32_args[0]) + (addr * self.pn32_args[1])
# wire[31:0] pn32_addr = pn32_A ^ coef0[31:0];
pn32_addr = pn32_A ^ self.coef0
pn32_v = pn32(pn32_addr)
# assign pnout = pn32[31:0] ^ {pn15[15:0], pn16[15:0]} ^ random[31:0];
pnout = pn32_v ^ ((pn15_v * 0x10000) + (pn16_v % 0x10000)) ^ self.random
return data ^ pnout

314
tools/util/bkutil.py Normal file
View File

@@ -0,0 +1,314 @@
# Copyright (c) Kuba Szczodrzyński 2022-06-10.
import sys
from os.path import dirname, join
sys.path.append(join(dirname(__file__), "..", ".."))
from argparse import ArgumentParser, FileType
from binascii import crc32
from dataclasses import dataclass, field
from enum import IntFlag
from io import SEEK_SET, FileIO
from os import stat
from struct import Struct
from time import time
from typing import Union
from tools.util.bitint import BitInt
from tools.util.bkcrypto import BekenCrypto
from tools.util.crc16 import CRC16
from tools.util.fileio import readbin, writebin
from tools.util.intbin import (
ByteGenerator,
align_up,
betoint,
biniter,
fileiter,
geniter,
inttobe16,
inttole32,
letoint,
pad_data,
pad_up,
)
class OTAAlgorithm(IntFlag):
NONE = 0
CRYPT_XOR = 1
CRYPT_AES256 = 2
COMPRESS_GZIP = 256
COMPRESS_QUICKLZ = 512
COMPRESS_FASTLZ = 768
@dataclass
class RBL:
ota_algo: OTAAlgorithm = OTAAlgorithm.NONE
timestamp: float = field(default_factory=time)
name: Union[str, bytes] = "app"
version: Union[str, bytes] = "1.00"
sn: Union[str, bytes] = "0" * 23
data_crc: int = 0
data_hash: int = 0x811C9DC5 # https://github.com/znerol/py-fnvhash/blob/master/fnvhash/__init__.py
raw_size: int = 0
data_size: int = 0
container_size: int = 0
has_part_table: bool = False
def update(self, data: bytes):
self.data_crc = crc32(data, self.data_crc)
for byte in data:
if self.data_size < self.raw_size:
self.data_hash ^= byte
self.data_hash *= 0x01000193
self.data_hash %= 0x100000000
self.data_size += 1
def serialize(self) -> bytes:
if isinstance(self.name, str):
self.name = self.name.encode()
if isinstance(self.version, str):
self.version = self.version.encode()
if isinstance(self.sn, str):
self.sn = self.sn.encode()
# based on https://github.com/khalednassar/bk7231tools/blob/main/bk7231tools/analysis/rbl.py
struct = Struct("<4sII16s24s24sIIII") # without header CRC
rbl = struct.pack(
b"RBL\x00",
self.ota_algo,
int(self.timestamp),
pad_data(self.name, 16, 0x00),
pad_data(self.version, 24, 0x00),
pad_data(self.sn, 24, 0x00),
self.data_crc,
self.data_hash,
self.raw_size,
self.data_size,
)
return rbl + inttole32(crc32(rbl))
@classmethod
def deserialize(cls, data: bytes) -> "RBL":
crc_found = letoint(data[-4:])
data = data[:-4]
crc_expected = crc32(data)
if crc_expected != crc_found:
raise ValueError(
f"Invalid RBL CRC (expected {crc_expected:X}, found {crc_found:X})"
)
struct = Struct("<II16s24s24sIIII") # without magic and header CRC
rbl = cls(*struct.unpack(data[4:]))
rbl.ota_algo = OTAAlgorithm(rbl.ota_algo)
rbl.name = rbl.name.partition(b"\x00")[0].decode()
rbl.version = rbl.version.partition(b"\x00")[0].decode()
rbl.sn = rbl.sn.partition(b"\x00")[0].decode()
return rbl
class BekenBinary:
crypto: BekenCrypto
def __init__(self, coeffs: Union[bytes, str] = None) -> None:
if coeffs:
if isinstance(coeffs, str):
coeffs = bytes.fromhex(coeffs)
if len(coeffs) != 16:
raise ValueError(
f"Invalid length of encryption coefficients: {len(coeffs)}"
)
coeffs = list(map(BitInt, map(betoint, biniter(coeffs, 4))))
self.crypto = BekenCrypto(coeffs)
def crc(self, data: ByteGenerator) -> ByteGenerator:
for block in geniter(data, 32):
crc = CRC16.CMS.calc(block)
yield block
yield inttobe16(crc)
def uncrc(self, data: ByteGenerator, check: bool = True) -> ByteGenerator:
for block in geniter(data, 34):
if check:
crc = CRC16.CMS.calc(block[0:32])
crc_found = betoint(block[32:34])
if crc != crc_found:
print(f"CRC invalid: expected={crc:X}, found={crc_found:X}")
return
yield block[0:32]
def crypt(self, addr: int, data: ByteGenerator) -> ByteGenerator:
for word in geniter(data, 4):
word = letoint(word)
word = self.crypto.encrypt_u32(addr, word)
word = inttole32(word)
yield word
addr += 4
def package(self, f: FileIO, addr: int, size: int, rbl: RBL) -> ByteGenerator:
if not rbl.container_size:
raise ValueError("RBL must have a total size when packaging")
crc_total = 0
# when to stop reading input data
data_end = size
if rbl.has_part_table:
data_end = size - 0xC0 # do not encrypt the partition table
# set RBL size including one 16-byte padding
rbl.raw_size = align_up(size + 16, 32) + 16
# encrypt the input file, padded to 32 bytes
data_crypt_gen = self.crypt(
addr, fileiter(f, size=32, padding=0xFF, count=data_end)
)
# iterate over encrypted 32-byte blocks
for block in geniter(data_crypt_gen, 32):
# add CRC16 and yield
yield from self.crc(block)
crc_total += 2
rbl.update(block)
# temporary buffer for small-size operations
buf = b"\xff" * 16 # add 16 bytes of padding
if rbl.has_part_table:
# add an unencrypted partition table
buf += f.read(0xC0)
# update RBL
rbl.update(buf)
# add last padding with different values
rbl.update(b"\x10" * 16)
# add last padding with normal values
buf += b"\xff" * 16
# yield the temporary buffer
yield from self.crc(buf)
crc_total += 2 * (len(buf) // 32)
# pad the entire container with 0xFF, excluding RBL and its CRC16
pad_size = pad_up(rbl.data_size + crc_total, rbl.container_size) - 102
for _ in range(pad_size):
yield b"\xff"
# yield RBL with CRC16
yield from self.crc(rbl.serialize())
def auto_int(x):
return int(x, 0)
def add_common_args(parser):
parser.add_argument(
"coeffs", type=str, help="Encryption coefficients (hex string, 32 chars)"
)
parser.add_argument("input", type=FileType("rb"), help="Input file")
parser.add_argument("output", type=FileType("wb"), help="Output file")
parser.add_argument("addr", type=auto_int, help="Memory address (dec/hex)")
if __name__ == "__main__":
parser = ArgumentParser(description="Encrypt/decrypt Beken firmware binaries")
sub = parser.add_subparsers(dest="action", required=True)
encrypt = sub.add_parser("encrypt", help="Encrypt binary files without packaging")
add_common_args(encrypt)
encrypt.add_argument("-c", "--crc", help="Include CRC16", action="store_true")
decrypt = sub.add_parser("decrypt", description="Decrypt unpackaged binary files")
add_common_args(decrypt)
decrypt.add_argument(
"-C",
"--no-crc-check",
help="Do not check CRC16 (if present)",
action="store_true",
)
package = sub.add_parser(
"package", description="Package raw binary files as RBL containers"
)
add_common_args(package)
package.add_argument("size", type=auto_int, help="RBL total size (dec/hex)")
package.add_argument(
"-n",
"--name",
type=str,
help="Firmware name (default: app)",
default="app",
required=False,
)
package.add_argument(
"-v",
"--version",
type=str,
help="Firmware version (default: 1.00)",
default="1.00",
required=False,
)
unpackage = sub.add_parser(
"unpackage", description="Unpackage a single RBL container"
)
add_common_args(unpackage)
unpackage.add_argument(
"offset", type=auto_int, help="Offset in input file (dec/hex)"
)
unpackage.add_argument("size", type=auto_int, help="RBL total size (dec/hex)")
args = parser.parse_args()
bk = BekenBinary(args.coeffs)
f: FileIO = args.input
size = stat(args.input.name).st_size
start = time()
if args.action == "encrypt":
print(f"Encrypting '{f.name}' ({size} bytes)")
if args.crc:
print(f" - calculating 32-byte block CRC16...")
gen = bk.crc(bk.crypt(args.addr, f))
else:
print(f" - as raw binary, without CRC16...")
gen = bk.crypt(args.addr, f)
if args.action == "decrypt":
print(f"Decrypting '{f.name}' ({size} bytes)")
if size % 34 == 0:
if args.no_crc_check:
print(f" - has CRC16, skipping checks...")
else:
print(f" - has CRC16, checking...")
gen = bk.crypt(args.addr, bk.uncrc(f, check=not args.no_crc_check))
elif size % 4 != 0:
raise ValueError("Input file has invalid length")
else:
print(f" - raw binary, no CRC")
gen = bk.crypt(args.addr, f)
if args.action == "package":
print(f"Packaging {args.name} '{f.name}' for memory address 0x{args.addr:X}")
rbl = RBL(name=args.name, version=args.version)
if args.name == "bootloader":
rbl.has_part_table = True
print(f" - in bootloader mode; partition table unencrypted")
rbl.container_size = args.size
print(f" - container size: 0x{args.size:X}")
gen = bk.package(f, args.addr, size, rbl)
if args.action == "unpackage":
print(f"Unpackaging '{f.name}' (at 0x{args.offset:X}, size 0x{args.size:X})")
f.seek(args.offset + args.size - 102, SEEK_SET)
rbl = f.read(102)
rbl = b"".join(bk.uncrc(rbl))
rbl = RBL.deserialize(rbl)
print(f" - found '{rbl.name}' ({rbl.version}), size {rbl.data_size}")
f.seek(0, SEEK_SET)
crc_size = (rbl.data_size - 16) // 32 * 34
gen = bk.crypt(args.addr, bk.uncrc(fileiter(f, 32, 0xFF, crc_size)))
written = 0
for data in gen:
args.output.write(data)
written += len(data)
print(f" - wrote {written} bytes in {time()-start:.3f} s")

View File

@@ -1,5 +1,10 @@
# Copyright (c) Kuba Szczodrzyński 2022-06-02.
from io import FileIO
from typing import IO, Generator, Union
ByteGenerator = Generator[bytes, None, None]
def bswap(data: bytes) -> bytes:
"""Reverse the byte array (big-endian <-> little-endian)."""
@@ -137,3 +142,63 @@ def uint32(val):
def uintmax(bits: int) -> int:
"""Get maximum integer size for given bit width."""
return (2**bits) - 1
def biniter(data: bytes, size: int) -> ByteGenerator:
"""Iterate over 'data' in 'size'-bytes long chunks, returning
a generator."""
if len(data) % size != 0:
raise ValueError(
f"Data length must be a multiple of block size ({len(data)} % {size})"
)
for i in range(0, len(data), size):
yield data[i : i + size]
def geniter(gen: Union[ByteGenerator, bytes, IO], size: int) -> ByteGenerator:
"""
Take data from 'gen' and generate 'size'-bytes long chunks.
If 'gen' is a bytes or IO object, it is wrapped using
biniter() or fileiter().
"""
if isinstance(gen, bytes):
yield from biniter(gen, size)
return
if isinstance(gen, IO):
yield from fileiter(gen, size)
return
buf = b""
for part in gen:
if not buf and len(part) == size:
yield part
continue
buf += part
while len(buf) >= size:
yield buf[0:size]
buf = buf[size:]
def fileiter(
f: FileIO, size: int, padding: int = 0x00, count: int = 0
) -> ByteGenerator:
"""
Read data from 'f' and generate 'size'-bytes long chunks.
Pad incomplete chunks with 'padding' character.
Read up to 'count' bytes from 'f', if specified. Data is padded
if not on chunk boundary.
"""
read = 0
while True:
if count and read + size >= count:
yield pad_data(f.read(count % size), size, padding)
return
data = f.read(size)
read += len(data)
if len(data) < size:
# got only part of the block
yield pad_data(data, size, padding)
return
yield data

View File

@@ -1,7 +1,9 @@
# Copyright (c) Kuba Szczodrzyński 2022-06-02.
import json
from typing import Union
from typing import Tuple, Union
SliceLike = Union[slice, str, int]
def merge_dicts(d1, d2, path=None):
@@ -27,3 +29,24 @@ def get(data: dict, path: str):
return data.get(path, None)
key, _, path = path.partition(".")
return get(data.get(key, None), path)
def slice2int(val: SliceLike) -> Tuple[int, int]:
"""Convert a slice-like value (slice, string '7:0' or '3', int '3')
to a tuple of (start, stop)."""
if isinstance(val, int):
return (val, val)
if isinstance(val, slice):
if val.step:
raise ValueError("value must be a slice without step")
if val.start < val.stop:
raise ValueError("start must not be less than stop")
return (val.start, val.stop)
if isinstance(val, str):
if ":" in val:
val = val.split(":")
if len(val) == 2:
return tuple(map(int, val))
elif val.isnumeric():
return (int(val), int(val))
raise ValueError(f"invalid slice format: {val}")