I attempted to implement a RSA algorithm in Python 3.12. I first implemented a textbook RSA algorithm, which was done successfully. (I verified this with many repeated attempts using various keys and messages, all of which encrypted and decrypted successfully.)
However, since I implemented Optimal Asymmetric Encryption Padding (OAEP) encoding and decoding, lHash is occasionally not equal to lHash' during the OAEP decoding process. (I'm using terminology from Wikipedia's page on OAEP: https://en.wikipedia.org/wiki/Optimal_asymmetric_encryption_padding). One quick glace at my code will reveal that many of the OAEP related functions were "inspired by" or taken directly from the following github project: https://gist.github.com/ppoffice/e10e0a418d5dafdd5efe9495e962d3d2.
Approximately half of the time, the lineassert(lhash == lhash_prime)will cause an assertion error. When running the program 100 times, 46% of the time, the assertion error was raised. I have a few examples of key values which worked as well as ones that did not.
I've attempted analysing patterns in the n, e, and d values that the RSA key generation produces which raise the assertion error. I believe that the n values in particular would be helpful, given OAEP uses the length of n as a significant part of the process. However with such large values, it's difficult for a beginner programmer like me to make sense of them.
When I didn't actually encrypt and decrypt messages with both textbook RSA and OAEP, and instead I simply encoded and decoded messages with only OAEP, the process worked fine. Additionally, any tests I did only using textbook RSA worked as well.
Below is the code for a minimal, reproducible example. I'm sorry, but even though I've tried my best to reduce how long it is (removed checks to ensure the primes p and q are actually secure primes, etc.), it's still quite long if the bug is to be reproduced.
import randomfrom math import ceilimport hashlibimport osfrom typing import Callabledef byte_len(n: int) -> int: return ceil(n.bit_length() / 8)def get_n_bit_rand_num(n: int) -> int: return random.randrange(2**(n-1)+1,2**n-1)def rabin_miller_composite_test(a: int, m: int, k: int, n: int) -> bool: if (pow(a,m,n) == 1): return False for i in range(k): if (pow(a,2**i*m,n) == n-1): return False return Truedef probablistic_is_prime_test(n: int) -> bool: first_primes_list = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349] for divisor in first_primes_list: if n % divisor == 0: return False k = 0 m = n-1 while (m % 2 == 0): m >>= 1 k += 1 iterations = 20 for _ in range(iterations): a = random.randrange(2,n-1) if rabin_miller_composite_test(a,m,k,n): return False return Truedef get_random_large_prime() -> int: num_is_prime = False num = 0 while(num_is_prime == False): num = get_n_bit_rand_num(1024) if probablistic_is_prime_test(num): num_is_prime = True return numdef euclidean_algorithm_GCD(larger_num: int, smaller_num: int) -> int: if (smaller_num == 0): return larger_num else: return euclidean_algorithm_GCD(smaller_num,larger_num % smaller_num)def extended_euclidean_algorithm_second_num_of_linear_combination(larger_num: int, smaller_num: int) -> int: s = 0 r = smaller_num old_r = larger_num old_s = 1 quotient = 0 temp = 0 while (r != 0): quotient = old_r // r temp = old_r old_r = r r = temp - quotient * r temp = old_s old_s = s s = temp - quotient * s second_num = (old_r - old_s * larger_num) // smaller_num return second_numdef generate_keys() -> tuple[int, int, int]: p = get_random_large_prime() q = get_random_large_prime() n = p*q phi_of_n = (p-1) * (q-1) e = 65537 while euclidean_algorithm_GCD(phi_of_n,e) != 1: e += 1 d = extended_euclidean_algorithm_second_num_of_linear_combination(phi_of_n,e) % phi_of_n return n, e, ddef textbook_encrypt_message(message: bytes, e: int, n: int) -> int: int_message = int.from_bytes(message, 'little') return pow(int_message,e,n)def textbook_decrypt_message(encrypted_message: int, d: int, n: int) -> bytes: int_message = pow(encrypted_message, d, n) return int_message.to_bytes(byte_len(int_message), 'little')def encrypt_message_oaep(message: bytes, e: int, n: int) -> int: n_byte_length = byte_len(n) padded_message = oaep_encode(message,n_byte_length) return textbook_encrypt_message(padded_message,e,n)def decrypt_message_oaep(encrypted_message: int, d: int, n: int) -> str: encoded_message = textbook_decrypt_message(encrypted_message, d, n) encoded_message_as_bytes = encoded_message n_byte_length = byte_len(n) message = oaep_decode(encoded_message_as_bytes,n_byte_length) return message.decode()def bytewise_xor(data: bytes, mask: bytes) -> bytes: masked = b"" for i in range(max(len(data),len(mask))): if i < len(data) and i < len(mask): masked += (data[i] ^ mask[i]).to_bytes(1, byteorder = 'big') elif i < len(data): masked += data[i].to_bytes(1, byteorder="big") else: break return maskeddef sha1(m: bytes) -> bytes:'''SHA-1 hash function''' hasher = hashlib.sha1() hasher.update(m) return hasher.digest()def mgf1(seed: bytes, mlen: int, f_hash: Callable = sha1) -> bytes: '''MGF1 mask generation function with SHA-1''' t = b'' hlen = len(f_hash(b'')) for c in range(0, ceil(mlen / hlen)): _c = c.to_bytes(4, byteorder="big") t += f_hash(seed + _c) return t[:mlen]def oaep_encode(message: bytes, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes: lhash = hash_func(label) padding_string = (k - len(message)-2*len(lhash)-2) * b"\x00" data_block = lhash + padding_string + b"\x01" + message seed = os.urandom(len(lhash)) data_block_mask = mgf(seed,k-len(lhash)-1,hash_func) masked_data_block = bytewise_xor(data_block,data_block_mask) seed_mask = mgf(masked_data_block,len(lhash),hash_func) masked_seed = bytewise_xor(seed,seed_mask) return b"\x00" + masked_seed + masked_data_blockdef oaep_decode(encoded_message: bytes, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes: lhash = hash_func(label) masked_seed = encoded_message[1:1 + len(lhash)] masked_data_block = encoded_message[1+len(lhash):] seed_mask = mgf(masked_data_block,len(lhash),hash_func) seed = bytewise_xor(masked_seed,seed_mask) data_block_mask = mgf(seed,k-len(lhash)-1,hash_func) data_block = bytewise_xor(masked_data_block, data_block_mask) lhash_prime = data_block[:len(lhash)] assert(lhash == lhash_prime) i = len(lhash) while i < len(data_block): if data_block[i] == 0: i += 1 continue elif data_block[i] == 1: i += 1 break else: raise Exception('This should never happen.') return data_block[i:][n, e, d] = generate_keys()print("n: ", n)print("e: ", e)print("d: ", d)message = "Imagine that this is some secure test message"oaep_encrypted_message = encrypt_message_oaep(message.encode(), e, n)print(oaep_encrypted_message)print(decrypt_message_oaep(oaep_encrypted_message, d, n))