Channel: Active questions tagged python - Stack Overflow

RSA implementation with OAEP occasionally fails because lHash and lHash' differ


I attempted to implement an RSA algorithm in Python 3.12. I first implemented textbook RSA, which worked. (I verified this with many repeated attempts using various keys and messages, all of which encrypted and decrypted successfully.)
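For comparison, a textbook RSA round trip can be checked by hand with the classic toy key p = 61, q = 53, e = 17. This is a sketch, not the asker's code: the hypothetical `keys_from_primes` helper uses `math.gcd` and the built-in `pow(e, -1, phi)` (Python 3.8+) in place of the hand-written Euclidean helpers, and the last lines show that an integer message m ≥ n cannot survive the round trip, because decryption can only return a value in [0, n).

```python
from math import gcd


def keys_from_primes(p: int, q: int, e: int = 65537) -> tuple[int, int, int]:
    # Hypothetical helper: prime generation is deliberately left out.
    n = p * q
    phi = (p - 1) * (q - 1)
    while gcd(phi, e) != 1:   # same search as the question: bump e until coprime
        e += 1
    d = pow(e, -1, phi)       # modular inverse of e modulo phi (Python 3.8+)
    return n, e, d


def encrypt(m: int, e: int, n: int) -> int:
    return pow(m, e, n)


def decrypt(c: int, d: int, n: int) -> int:
    return pow(c, d, n)


# Toy primes: far too small to be secure, used only to check the arithmetic.
n, e, d = keys_from_primes(61, 53, e=17)
print(n, e, d)                               # 3233 17 2753
print(decrypt(encrypt(65, e, n), d, n))      # 65

# A message representative that is not smaller than n is reduced modulo n,
# so decryption returns m mod n instead of m.
print(decrypt(encrypt(n + 65, e, n), d, n))  # 65, not 3298
```

Any short message whose integer value is below n round-trips, which matches the observation that textbook RSA alone always worked.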

However, since I added Optimal Asymmetric Encryption Padding (OAEP) encoding and decoding, lHash is occasionally not equal to lHash' during OAEP decoding. (I'm using the terminology from Wikipedia's page on OAEP: https://en.wikipedia.org/wiki/Optimal_asymmetric_encryption_padding.) A quick glance at my code will reveal that many of the OAEP-related functions were "inspired by" or taken directly from the following GitHub gist: https://gist.github.com/ppoffice/e10e0a418d5dafdd5efe9495e962d3d2.
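The OAEP layout from that Wikipedia page (which follows PKCS #1, RFC 8017) can be sketched on bytes alone, with no RSA step: DB = lHash || PS || 0x01 || M and EM = 0x00 || maskedSeed || maskedDB. This compact sketch assumes SHA-1, MGF1 and an empty label, as in the question's code. Unlike the question's `oaep_decode`, it rejects input that is not exactly k bytes with a leading 0x00, so a length or byte-order problem in the integer conversion would show up as an explicit error rather than as a hash mismatch.

```python
import hashlib
import os

H_LEN = hashlib.sha1().digest_size  # 20 bytes for SHA-1


def mgf1(seed: bytes, length: int) -> bytes:
    # MGF1 with SHA-1: concatenate Hash(seed || 4-byte big-endian counter).
    out = b""
    counter = 0
    while len(out) < length:
        out += hashlib.sha1(seed + counter.to_bytes(4, "big")).digest()
        counter += 1
    return out[:length]


def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))


def oaep_encode(message: bytes, k: int, label: bytes = b"") -> bytes:
    if len(message) > k - 2 * H_LEN - 2:
        raise ValueError("message too long")
    l_hash = hashlib.sha1(label).digest()
    ps = b"\x00" * (k - len(message) - 2 * H_LEN - 2)
    db = l_hash + ps + b"\x01" + message          # DB = lHash || PS || 0x01 || M
    seed = os.urandom(H_LEN)
    masked_db = xor(db, mgf1(seed, k - H_LEN - 1))
    masked_seed = xor(seed, mgf1(masked_db, H_LEN))
    return b"\x00" + masked_seed + masked_db      # EM = 0x00 || maskedSeed || maskedDB


def oaep_decode(em: bytes, k: int, label: bytes = b"") -> bytes:
    # Expects em to be exactly k bytes long, including the leading 0x00.
    if len(em) != k or em[0] != 0:
        raise ValueError("decryption error")
    l_hash = hashlib.sha1(label).digest()
    masked_seed, masked_db = em[1:1 + H_LEN], em[1 + H_LEN:]
    seed = xor(masked_seed, mgf1(masked_db, H_LEN))
    db = xor(masked_db, mgf1(seed, k - H_LEN - 1))
    if db[:H_LEN] != l_hash:
        raise ValueError("decryption error")
    rest = db[H_LEN:].lstrip(b"\x00")             # skip the zero padding PS
    if not rest.startswith(b"\x01"):
        raise ValueError("decryption error")
    return rest[1:]


em = oaep_encode(b"test message", 256)
print(len(em), oaep_decode(em, 256))  # 256 b'test message'
```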

Approximately half of the time, the line `assert(lhash == lhash_prime)` raises an AssertionError: over 100 runs of the program, it was raised 46 times. I have a few examples of key values that worked as well as ones that did not.
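One hypothesis consistent with a failure rate near one half (an assumption, not something the question confirms) is the byte order used to turn the OAEP block into an integer. OAEP places a 0x00 byte first so that, read big-endian, the k-byte block is below 2^(8(k-1)) and therefore below n. `textbook_encrypt_message` reads the block with `'little'`, so that 0x00 becomes the least significant byte while the last byte of maskedDB, which is random, becomes the most significant one. The sketch below uses a made-up 2047-bit modulus (just a number of the right size, not an RSA key) to compare the two readings.

```python
import os

# Hypothetical modulus: only its size matters here; it is not a real RSA key.
n = (1 << 2046) | 1             # a 2047-bit number, so its byte length is 256
k = (n.bit_length() + 7) // 8   # 256, the same value as byte_len(n)

# An OAEP-shaped block: a leading zero byte followed by random bytes.
em = b"\x00" + os.urandom(k - 1)

big = int.from_bytes(em, "big")
little = int.from_bytes(em, "little")

print(big < n)     # True: the leading 0x00 is the most significant byte
print(little < n)  # depends on em[-1], which is random
```

When the little-endian value is ≥ n, RSA effectively works on it modulo n, so the bytes recovered after decryption are different from the ones that were encrypted and the lHash comparison fails.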

I've tried analysing patterns in the n, e, and d values produced by key generation in the runs that raise the assertion error. I believe the n values in particular would be informative, since OAEP uses the byte length of n as a significant part of the process. However, with such large values, it's difficult for a beginner programmer like me to make sense of them.
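Rather than inspecting the digits of n, it may help to estimate how often an OAEP block read little-endian (as `textbook_encrypt_message` does) would be ≥ n. Under two simplifying assumptions, that p and q behave like uniform 1024-bit integers (primality ignored) and that the little-endian block value is uniform over 256 bytes, the expected rate is 1 − E[n]/2^2048 = 1 − (3/4)^2 ≈ 0.44, which is in the same range as the 46% reported. The simulation below checks that back-of-the-envelope model; it is an estimate, not a proof.

```python
import random


def estimate_wrap_rate(trials: int = 20_000, seed: int = 0) -> float:
    rng = random.Random(seed)
    wraps = 0
    for _ in range(trials):
        # Assumption: p and q are uniform 1024-bit integers (primality ignored).
        p = rng.randrange(2**1023, 2**1024)
        q = rng.randrange(2**1023, 2**1024)
        n = p * q
        # Assumption: the little-endian block value is uniform over 2048 bits.
        m = rng.getrandbits(2048)
        if m >= n:
            wraps += 1
    return wraps / trials


print(estimate_wrap_rate())  # about 0.44
```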

When I skipped RSA encryption and decryption and simply encoded and decoded messages with OAEP alone, the process worked fine. Likewise, every test I ran using only textbook RSA worked.
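Both halves working separately points at the glue between them. In the posted code that glue reads the block little-endian and, on decryption, writes back only as many bytes as the integer needs. PKCS #1 (RFC 8017) instead converts with OS2IP/I2OSP: big-endian, padded back to exactly k = byte_len(n) bytes, and RSAEP rejects a representative that is not below n. The function names below are hypothetical stand-ins for that conversion in the OAEP path, checked with the toy key n = 3233 rather than a real one.

```python
def os2ip(data: bytes) -> int:
    # Octet string to integer, big-endian (PKCS #1 OS2IP).
    return int.from_bytes(data, "big")


def i2osp(x: int, length: int) -> bytes:
    # Integer to a fixed-length big-endian octet string (PKCS #1 I2OSP).
    # int.to_bytes raises OverflowError if x does not fit in `length` bytes.
    return x.to_bytes(length, "big")


def rsa_encrypt_block(em: bytes, e: int, n: int) -> int:
    m = os2ip(em)
    if m >= n:
        raise ValueError("message representative out of range")
    return pow(m, e, n)


def rsa_decrypt_block(c: int, d: int, n: int) -> bytes:
    k = (n.bit_length() + 7) // 8   # same value as byte_len(n)
    return i2osp(pow(c, d, n), k)   # keeps any leading zero bytes


# Toy key (p = 61, q = 53), for illustration only.
n, e, d = 3233, 17, 2753
em = b"\x00\x41"                    # two bytes with a leading zero, like an OAEP block
print(rsa_decrypt_block(rsa_encrypt_block(em, e, n), d, n) == em)  # True
```

With helpers like these, `encrypt_message_oaep` would pass the padded block to `rsa_encrypt_block`, and `decrypt_message_oaep` would hand the k-byte result of `rsa_decrypt_block` to `oaep_decode`.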

Below is the code for a minimal, reproducible example. I'm sorry that it's still quite long: I've tried my best to shorten it (for example, I removed the checks that ensure the primes p and q are actually secure primes), but this is as short as I could make it while still reproducing the bug.

```python
import random
from math import ceil
import hashlib
import os
from typing import Callable


def byte_len(n: int) -> int:
    return ceil(n.bit_length() / 8)


def get_n_bit_rand_num(n: int) -> int:
    return random.randrange(2**(n-1)+1,2**n-1)


def rabin_miller_composite_test(a: int, m: int, k: int, n: int) -> bool:
    if (pow(a,m,n) == 1):
        return False
    for i in range(k):
        if (pow(a,2**i*m,n) == n-1):
            return False
    return True


def probablistic_is_prime_test(n: int) -> bool:
    first_primes_list = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29,
                         31, 37, 41, 43, 47, 53, 59, 61, 67,
                         71, 73, 79, 83, 89, 97, 101, 103,
                         107, 109, 113, 127, 131, 137, 139,
                         149, 151, 157, 163, 167, 173, 179,
                         181, 191, 193, 197, 199, 211, 223,
                         227, 229, 233, 239, 241, 251, 257,
                         263, 269, 271, 277, 281, 283, 293,
                         307, 311, 313, 317, 331, 337, 347, 349]
    for divisor in first_primes_list:
        if n % divisor == 0:
            return False
    k = 0
    m = n-1
    while (m % 2 == 0):
        m >>= 1
        k += 1
    iterations = 20
    for _ in range(iterations):
        a = random.randrange(2,n-1)
        if rabin_miller_composite_test(a,m,k,n):
            return False
    return True


def get_random_large_prime() -> int:
    num_is_prime = False
    num = 0
    while(num_is_prime == False):
        num = get_n_bit_rand_num(1024)
        if probablistic_is_prime_test(num):
            num_is_prime = True
    return num


def euclidean_algorithm_GCD(larger_num: int, smaller_num: int) -> int:
    if (smaller_num == 0):
        return larger_num
    else:
        return euclidean_algorithm_GCD(smaller_num,larger_num % smaller_num)


def extended_euclidean_algorithm_second_num_of_linear_combination(larger_num: int, smaller_num: int) -> int:
    s = 0
    r = smaller_num
    old_r = larger_num
    old_s = 1
    quotient = 0
    temp = 0
    while (r != 0):
        quotient = old_r // r
        temp = old_r
        old_r = r
        r = temp - quotient * r
        temp = old_s
        old_s = s
        s = temp - quotient * s
    second_num = (old_r - old_s * larger_num) // smaller_num
    return second_num


def generate_keys() -> tuple[int, int, int]:
    p = get_random_large_prime()
    q = get_random_large_prime()
    n = p*q
    phi_of_n = (p-1) * (q-1)
    e = 65537
    while euclidean_algorithm_GCD(phi_of_n,e) != 1:
        e += 1
    d = extended_euclidean_algorithm_second_num_of_linear_combination(phi_of_n,e) % phi_of_n
    return n, e, d


def textbook_encrypt_message(message: bytes, e: int, n: int) -> int:
    int_message = int.from_bytes(message, 'little')
    return pow(int_message,e,n)


def textbook_decrypt_message(encrypted_message: int, d: int, n: int) -> bytes:
    int_message = pow(encrypted_message, d, n)
    return int_message.to_bytes(byte_len(int_message), 'little')


def encrypt_message_oaep(message: bytes, e: int, n: int) -> int:
    n_byte_length = byte_len(n)
    padded_message = oaep_encode(message,n_byte_length)
    return textbook_encrypt_message(padded_message,e,n)


def decrypt_message_oaep(encrypted_message: int, d: int, n: int) -> str:
    encoded_message = textbook_decrypt_message(encrypted_message, d, n)
    encoded_message_as_bytes = encoded_message
    n_byte_length = byte_len(n)
    message = oaep_decode(encoded_message_as_bytes,n_byte_length)
    return message.decode()


def bytewise_xor(data: bytes, mask: bytes) -> bytes:
    masked = b""
    for i in range(max(len(data),len(mask))):
        if i < len(data) and i < len(mask):
            masked += (data[i] ^ mask[i]).to_bytes(1, byteorder = 'big')
        elif i < len(data):
            masked += data[i].to_bytes(1, byteorder="big")
        else:
            break
    return masked


def sha1(m: bytes) -> bytes:
    '''SHA-1 hash function'''
    hasher = hashlib.sha1()
    hasher.update(m)
    return hasher.digest()


def mgf1(seed: bytes, mlen: int, f_hash: Callable = sha1) -> bytes:
    '''MGF1 mask generation function with SHA-1'''
    t = b''
    hlen = len(f_hash(b''))
    for c in range(0, ceil(mlen / hlen)):
        _c = c.to_bytes(4, byteorder="big")
        t += f_hash(seed + _c)
    return t[:mlen]


def oaep_encode(message: bytes, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes:
    lhash = hash_func(label)
    padding_string = (k - len(message)-2*len(lhash)-2) * b"\x00"
    data_block = lhash + padding_string + b"\x01" + message
    seed = os.urandom(len(lhash))
    data_block_mask = mgf(seed,k-len(lhash)-1,hash_func)
    masked_data_block = bytewise_xor(data_block,data_block_mask)
    seed_mask = mgf(masked_data_block,len(lhash),hash_func)
    masked_seed = bytewise_xor(seed,seed_mask)
    return b"\x00" + masked_seed + masked_data_block


def oaep_decode(encoded_message: bytes, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes:
    lhash = hash_func(label)
    masked_seed = encoded_message[1:1 + len(lhash)]
    masked_data_block = encoded_message[1+len(lhash):]
    seed_mask = mgf(masked_data_block,len(lhash),hash_func)
    seed = bytewise_xor(masked_seed,seed_mask)
    data_block_mask = mgf(seed,k-len(lhash)-1,hash_func)
    data_block = bytewise_xor(masked_data_block, data_block_mask)
    lhash_prime = data_block[:len(lhash)]
    assert(lhash == lhash_prime)
    i = len(lhash)
    while i < len(data_block):
        if data_block[i] == 0:
            i += 1
            continue
        elif data_block[i] == 1:
            i += 1
            break
        else:
            raise Exception('This should never happen.')
    return data_block[i:]


[n, e, d] = generate_keys()
print("n: ", n)
print("e: ", e)
print("d: ", d)
message = "Imagine that this is some secure test message"
oaep_encrypted_message = encrypt_message_oaep(message.encode(), e, n)
print(oaep_encrypted_message)
print(decrypt_message_oaep(oaep_encrypted_message, d, n))
```
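A second, rarer mismatch also seems possible in `textbook_decrypt_message` as posted (a hedged reading of the code, not a confirmed diagnosis): `int_message.to_bytes(byte_len(int_message), 'little')` emits only as many bytes as the integer needs, so if the last byte of the OAEP block happens to be 0x00 (about 1 time in 256 for random bytes), the recovered block is one byte short and the masks computed in `oaep_decode` no longer line up. The 4-byte block below is a hypothetical stand-in for a 256-byte one.

```python
from math import ceil


def byte_len(n: int) -> int:
    return ceil(n.bit_length() / 8)


block = b"\x00" + b"\x12\x34" + b"\x00"    # hypothetical block ending in 0x00
x = int.from_bytes(block, "little")

roundtrip = x.to_bytes(byte_len(x), "little")
print(len(block), len(roundtrip))          # 4 3

# Padding back to the known block length restores the dropped byte.
fixed = x.to_bytes(len(block), "little")
print(fixed == block)                      # True
```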
