I'm trying to code an OCR project to identify digits without using any extra libraries. I'm using Pygame to draw and render the number to identify. I managed to get all of it working, but for some weird reason, some specific numbers like 9 or 6 never get recognized by the algorithm (not even when I set k really high; none of the candidates is ever that number). I have absolutely no idea why, and any help would be really appreciated.
import math
import random
import time
from sys import exit

import pygame

# Number of training samples the drawing is compared against.
number_comparisons = 60000

DATA_DIR = r"C:/Users/----/Downloads/OCR/"
TEST_DIR = r"C:/Users/----/Downloads/OCR/test/"
TEST_DATA_FILENAME = DATA_DIR + "t10k-images.idx3-ubyte"
TEST_LABELS_FILENAME = DATA_DIR + "t10k-labels.idx1-ubyte"
TRAIN_DATA_FILENAME = DATA_DIR + "train-images.idx3-ubyte"
TRAIN_LABELS_FILENAME = DATA_DIR + "train-labels.idx1-ubyte"

start_time = time.time()
DEBUG = True

# Drawing grid geometry: 28 x 28 cells of 28 x 28 screen pixels each.
GRID_CELLS = 28
CELL_SIZE = 28
CANVAS_SIZE = GRID_CELLS * CELL_SIZE  # 784
# Mouse presses at or below this y coordinate trigger a guess instead of
# drawing, so the bottom grid row behaves as part of the red button.
GUESS_ZONE_TOP = 756

# Start pygame (images and sounds).
pygame.init()
pygame.mixer.init()

# Create the screen.
width = 784
height = 850
screen = pygame.display.set_mode((width, height))

# Set the window title.
pygame.display.set_caption("Optical Character Recognition")

# Clock used to cap the frame rate.
clock = pygame.time.Clock()


def read_labels(filename, n_max_labels=None):
    """Read labels from an IDX1 label file.

    Args:
        filename: Path of the label file.
        n_max_labels: Optional cap on the number of labels returned. It is
            clamped to the count stored in the file header.

    Returns:
        A list of 1-byte ``bytes`` objects, one per label.
    """
    with open(filename, "rb") as f:
        _ = f.read(4)  # first 4 header bytes are not needed here
        n_labels = bytes_to_int(f.read(4))
        if n_max_labels:
            n_labels = min(n_labels, n_max_labels)
        raw = f.read(n_labels)
    # Slicing (not indexing) keeps every label as a 1-byte bytes object.
    return [raw[i:i + 1] for i in range(n_labels)]


# Total number of pixels read by read_images (reported by main).
count = 0


def read_images(filename, n_max_images=None):
    """Read images from an IDX3 image file.

    Args:
        filename: Path of the image file.
        n_max_images: Optional cap on the number of images returned. It is
            clamped to the count stored in the file header.

    Returns:
        A list of images; each image is a list of rows and each row is a
        list of 1-byte ``bytes`` pixels.
    """
    global count
    images = []
    with open(filename, "rb") as f:
        _ = f.read(4)  # first 4 header bytes are not needed here
        # The next 12 bytes hold the image count, row count and column count.
        n_images = bytes_to_int(f.read(4))
        if n_max_images:
            n_images = min(n_images, n_max_images)
        n_rows = bytes_to_int(f.read(4))
        n_columns = bytes_to_int(f.read(4))
        for _image_idx in range(n_images):
            image = []
            for _row_idx in range(n_rows):
                # One read per row instead of one read per pixel.
                raw = f.read(n_columns)
                image.append([raw[i:i + 1] for i in range(n_columns)])
                count += n_columns
            images.append(image)
    return images


def bytes_to_int(byte_data):
    """Convert a pixel, label or header field to an int.

    Cells of the drawing grid are plain ints (0 or 200) while values read
    from the data files are ``bytes``; both kinds pass through here.

    Args:
        byte_data: The int 0, the int 200, or a ``bytes`` object.

    Returns:
        0 for the int 0, 255 for the int 200 (a drawn cell counts as full
        intensity), otherwise the big-endian integer value of the bytes.
    """
    if byte_data == 0:
        return 0
    if byte_data == 200:
        return 255
    return int.from_bytes(byte_data, "big")


def pasar_lista_unidimensional(X):
    """Flatten one 2D grid row by row and wrap it as a one-sample dataset."""
    return [aplanar_lista(X)]


def pasar_lista_unidimensional2(X):
    """Flatten every 2D sample of ``X`` into a 1D list of pixels."""
    return [aplanar_lista(sample) for sample in X]


def aplanar_lista(l):
    """Flatten a list of rows into a single list, row by row."""
    return [pixel for sublist in l for pixel in sublist]


def dist(x, y):
    """Return the Euclidean distance between two flattened samples."""
    return sum(
        (bytes_to_int(x_i) - bytes_to_int(y_i)) ** 2 for x_i, y_i in zip(x, y)
    ) ** 0.5


def distancia_entre_samples(X_train, test_sample):
    """Return the distance from ``test_sample`` to every training sample."""
    return [dist(train_sample, test_sample) for train_sample in X_train]


def most_frequent_element(list):
    """Return the most common element; ties go to the earliest element.

    The parameter name shadows the builtin ``list``; it is kept so existing
    keyword callers keep working.
    """
    return max(list, key=list.count)


# NOTE(review): the drawing is compared with the training images as-is. The
# MNIST samples are size-normalized and centered inside the 28x28 field, so
# a large or off-centre drawing may sit far from every sample of some
# classes (possibly why 6 and 9 never appear). Consider centring / scaling
# the drawing before calling knn -- to be confirmed experimentally.
def knn(X_train, y_train, X_test, k=3):
    """Classify each test sample by majority vote of its k nearest samples.

    Args:
        X_train: Flattened training samples.
        y_train: Labels matching ``X_train`` (1-byte ``bytes`` objects).
        X_test: Flattened samples to classify.
        k: Number of nearest neighbours that vote.

    Returns:
        A list with one predicted int label per test sample.
    """
    y_pred = []
    print("Using the knn algorithm to determine k nearest numbers to drawing...")
    for test_sample in X_test:
        training_distances = distancia_entre_samples(X_train, test_sample)
        # Training indices ordered from nearest to farthest.
        sorted_distance_indices = [
            pair[0]
            for pair in sorted(enumerate(training_distances), key=lambda x: x[1])
        ]
        candidates = [
            bytes_to_int(y_train[idx]) for idx in sorted_distance_indices[:k]
        ]
        print("Top k choices were", candidates)
        y_pred.append(most_frequent_element(candidates))
    return y_pred


def main():
    """Load the training set and classify the drawing in global ``X_test``."""
    print("Reading training files...")
    # "X" is the dataset and "y" holds the label assigned to each sample.
    X_train = read_images(TRAIN_DATA_FILENAME, number_comparisons)
    y_train = read_labels(TRAIN_LABELS_FILENAME, number_comparisons)
    print("Converting drawing to 2D grid")
    # Each training image becomes one flat list of pixels.
    X_train = pasar_lista_unidimensional2(X_train)
    y_pred = knn(X_train, y_train, X_test, 15)
    print("The number you have just written is: ", y_pred)
    print("Number of iterations: ", count)


# Grid of drawn cells, indexed [row][column]: 0 = empty, 200 = drawn.
image_array = []
font = pygame.font.Font(None, 24)

for i in range(0, 784, 28):
    pygame.draw.line(screen, "white", (0, i), (784, i))
for j in range(0, width, 28):
    image_array.append([0] * 28)
    pygame.draw.line(screen, "white", (j, 0), (j, 784))
print(type(image_array), type(image_array[0]), type(image_array[0][0]))


def draw(x, y):
    """Paint the cell under screen position (x, y) and its four neighbours.

    Cells outside the grid are skipped. Painted cells are set to 200 in
    ``image_array`` and filled white on screen.
    """
    column = x // CELL_SIZE
    row = y // CELL_SIZE
    # Centre, right, left, below, above: a plus-shaped brush.
    for d_column, d_row in ((0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)):
        c = column + d_column
        r = row + d_row
        if 0 <= c < GRID_CELLS and 0 <= r < GRID_CELLS:
            image_array[r][c] = 200
            pygame.draw.rect(
                screen,
                "white",
                (c * CELL_SIZE, r * CELL_SIZE, CELL_SIZE, CELL_SIZE),
            )


button = False
text = font.render("Guess Number", True, (0, 0, 0))
while True:
    # Handle window close and mouse button state.
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            # pygame.QUIT is the event constant; pygame.quit() shuts down.
            pygame.quit()
            exit()
        if event.type == pygame.MOUSEBUTTONDOWN:
            button = True
        elif event.type == pygame.MOUSEBUTTONUP:
            button = False

    if button:
        pos = pygame.mouse.get_pos()
        if pos[1] < GUESS_ZONE_TOP:
            draw(pos[0], pos[1])
        else:
            X_test = pasar_lista_unidimensional(image_array)
            print(X_test)
            print("Loading...")
            main()
            # main() blocks for a long time; do not rely on the button-up
            # event still being delivered afterwards.
            button = False
            # NOTE: the canvas is not cleared after a guess, so later
            # strokes are added to the same image.

    # Guess button with its label.
    pygame.draw.rect(screen, "red", (0, 784, 784, 66))
    screen.blit(text, text.get_rect(center=(width // 2, CANVAS_SIZE + 33)))
    pygame.display.update()
    # Frame rate cap.
    clock.tick(60)