Sign In

.Ckpt To .Safestensor <-> .Safestensor To .Ckpt - Google Colab (esp)

2

.Ckpt To .Safestensor  <-> .Safestensor  To  .Ckpt - Google Colab (esp)

El c贸digo que he desarrollado tiene como objetivo principal convertir un archivo de checkpoint (.ckpt) a un formato de "safetensor" (.safetensors). Un safetensor es un formato seguro y compatible con la biblioteca safetensors que puede ser utilizado para cargar y guardar tensores de PyTorch de manera segura.

Par usar este c贸digo se solicita al usuario que proporcione la direcci贸n del archivo .ckpt que se desea convertir, partimos de la base de que era inseguro, as铆 que l贸gicamente no lo descargaremos en nuestra pc y no conectaremos nuestro drive. Por lo que lo descargaremos de Hugging face en 1 minuto:

Luego de eso Haremos la convercion, usando esto, el archivo simple se descargara como Nomana, dado que casi no hay dependencias y simpre se puede borrar el anterior (Nomana)

#@markdown ### 5锔忊儯 Tranforma el modelo a convertir
import os
import torch
from safetensors.torch import save_file

file_path = '/content/Nomana.ckpt' #@param {type:"string"}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not os.path.isfile(file_path):
    print("El archivo no existe.")
else:
    if file_path.lower().endswith('.ckpt'):
        print(f'Loading {file_path}..si notas que comenzo el consumo ram, entonces el codigo esta andando, el proceso demora 2 minutos sin gpu y 1 minuto en entornos con gpu (Test: 4 - Gigas)')
        fn = f"{file_path.replace('.ckpt', '')}.safetensors"

        if os.path.isfile(fn):
            print(f'Skipping, as {fn} already exists.')
        else:
            try:
                with torch.no_grad():
                    weights = torch.load(file_path, map_location=device)["state_dict"]
                    print(f'Saving {fn}...')
                    save_file(weights, fn)
            except Exception as ex:
                print(f'ERROR converting {file_path}: {ex}')
    else:
        print("El archivo no es un archivo .ckpt.")

print('Terminaste')

Y por ultimo descargamos el safestensor

A continuaci贸n un resumen de lo 煤nico que importa y jodido de este c贸digo:

  with torch.no_grad():
    weights = torch.load(file_path, map_location=device)["state_dict"]
    print(f'Saving {fn}...')
    save_file(weights, fn)

Aqu铆 se utiliza el bloque with torch.no_grad() para indicar que no se realizar谩n c谩lculos de gradiente durante la carga y el guardado de los pesos del modelo. Esto es 煤til cuando solo estamos interesados en la inferencia y no en la optimizaci贸n o el ajuste de los par谩metros del modelo.

A continuaci贸n, utilizamos torch.load(file_path, map_location=device) para cargar los pesos del archivo .ckpt. Veamos algunas consideraciones importantes de esta l铆nea de c贸digo:

  • file_path: Es la ruta al archivo .ckpt que se va a cargar.

  • map_location=device: Especifica el dispositivo donde queremos cargar los pesos. Si torch.cuda.is_available() es True, se usar谩 la GPU (device = "cuda"), de lo contrario, se utilizar谩 la CPU (device = "cpu").

Luego de cargar los pesos, accedemos al diccionario resultante utilizando ["state_dict"]. Este diccionario contiene los pesos del modelo que se guardaron previamente.

Despu茅s de cargar los pesos, se muestra un mensaje indicando que se est谩 guardando el archivo fn (que es el nombre del archivo .safetensors que se generar谩 a partir del archivo .ckpt original).

Finalmente, utilizamos save_file(weights, fn) para guardar los pesos en el formato .safetensors utilizando la funci贸n save_file de la biblioteca safetensors.

En resumen, esta secci贸n del c贸digo carga los pesos del archivo .ckpt en el dispositivo especificado, los guarda en el formato .safetensors y muestra un mensaje de confirmaci贸n. Es importante tener en cuenta que el bloque with torch.no_grad() se utiliza para asegurarse de que no se realicen c谩lculos de gradiente durante este proceso.

En PyTorch, los modelos de aprendizaje profundo est谩n compuestos por capas y cada capa contiene par谩metros que se conocen como "pesos". Estos pesos son los valores num茅ricos que se utilizan en las operaciones matem谩ticas realizadas por las capas para realizar la inferencia o el entrenamiento del modelo.

El diccionario en PyTorch se utiliza para almacenar y acceder a estos pesos dentro de un modelo. El diccionario se denomina com煤nmente "state_dict" y es una estructura de datos que mapea nombres de capas a sus respectivos pesos. Cada par de clave-valor en el diccionario representa el nombre de una capa y su tensor de pesos correspondiente.

Cuando se guarda un modelo entrenado en PyTorch, generalmente se guarda el diccionario "state_dict" que contiene 煤nicamente los pesos del modelo, ya que estos son los par谩metros principales necesarios para realizar la inferencia. Guardar solo los pesos en lugar del modelo completo ayuda a reducir el tama帽o del archivo y permite una carga m谩s r谩pida.

En el c贸digo mencionado, la l铆nea weights = torch.load(file_path, map_location=device)["state_dict"] carga el diccionario "state_dict" de un archivo .ckpt. Utilizamos la funci贸n torch.load() para cargar los datos almacenados en el archivo, y luego accedemos a la clave "state_dict" para obtener el diccionario de pesos espec铆ficamente.

Una vez que tenemos el diccionario de pesos, podemos utilizarlos para realizar inferencias en nuevos datos o realizar otras operaciones relacionadas con el modelo.

Es importante tener en cuenta que los nombres de las capas en el diccionario "state_dict" se corresponden con los nombres de las capas en el modelo original. Esto significa que, al cargar los pesos en otro modelo, los nombres de las capas deben coincidir para que los pesos se asignen correctamente.

Ind:

Por lo tanto todo se resumen que voz estas sacando el agua de la botella, para ponerla en una jarra.

~~~~~~

Error: - KeyError: 'state_dict'

Para solucionarlo hay que colocar esto en el c贸digo:

if 'state_dict' in weights:
    weights.pop("state_dict')

Explicaci贸n:

Este fragmento de c贸digo es 煤til en situaciones espec铆ficas cuando se carga un archivo de checkpoint (.ckpt) y el diccionario de pesos contiene una clave adicional llamada "state_dict".

La necesidad de utilizar este fragmento de c贸digo surge debido a las diferencias en c贸mo se guardan los modelos en PyTorch. Algunas implementaciones de guardado de modelos, como torch.nn.DataParallel, incluyen autom谩ticamente un diccionario adicional con una clave "state_dict". Esta clave adicional puede causar problemas al intentar cargar los pesos en un modelo sin esa clave.

En esos casos, se utiliza if 'state_dict' in weights para verificar si existe la clave "state_dict" en el diccionario de pesos cargado. Si existe, se utiliza weights.pop("state_dict") para eliminar la clave y el diccionario de pesos se ajusta para que coincida con la estructura esperada. Esto asegura que los pesos se carguen correctamente en el modelo sin la clave adicional.

Sin embargo, es importante tener en cuenta que no todos los archivos de checkpoint tendr谩n una clave "state_dict" adicional. Por lo tanto, este fragmento de c贸digo solo debe utilizarse si se encuentra que el diccionario de pesos contiene dicha clave y es necesario eliminarla para cargar correctamente los pesos en el modelo deseado.

A continuaci贸n, se muestra una lista de implementaciones y casos comunes en los que se puede encontrar este problema:

  1. torch.nn.DataParallel: Esta clase se utiliza para paralelizar el c谩lculo en m煤ltiples GPUs. Al guardar un modelo que se ha envuelto con DataParallel, puede generar un archivo de checkpoint con una clave "state_dict" adicional en el diccionario de pesos.

  2. torch.nn.DataParallel + torch.nn.ModuleDict: Si se utiliza DataParallel en combinaci贸n con un diccionario de m贸dulos (ModuleDict), podr铆a generarse una clave "state_dict" adicional para el diccionario de pesos de cada m贸dulo.

  3. torch.nn.DataParallel + torch.nn.ModuleList: De manera similar a la combinaci贸n anterior, si se utiliza DataParallel con una lista de m贸dulos (ModuleList), podr铆a generarse una clave "state_dict" adicional para cada m贸dulo en la lista.

Eso es todo difruten del codigo.

~~~~~~~~~

La operacion contrai seria convertir .safestensor To .Ckpt (Por que es util?), Basicamente por que hay muchicima documentacion y scrips que se basan en .ckpt y no .safestenosr , por lo que para usar algunos programas nos vamos a encontrar que es un inconveniente trabajar con .safestensor .

Basicamente el programa cuenta con dos partes:

Dependencias y Descargas:

Tranformacion y verificacion :

~~~~~~~~

Ck To Safe :

Open in Gato

Open in Colab

Safe To Ck

Open in Gato

Open in Colab

~~~~~~~~

Financial assistance: Hello everyone!

This is Tomas Agilar speaking, and I'm thrilled to have the opportunity to share my work and passion with all of you. If you enjoy what I do and would like to support me, there are a few ways you can do so:

~~~~~~~~~~~~~~

Pensaba que no te hiba a dar el copy pasta?, claro que no campion aca tenes tu codigo libre de ezfuerso si te salta el Error: - KeyError: 'state_dict' . Solo recorda que puede traer probemas dropiar el dicionario donde se guarda todo lo importante si el mismo no estaba duplicado

import os
import torch
from safetensors.torch import save_file

file_path = '/content/Nomana.ckpt' #@param {type:"string"}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not os.path.isfile(file_path):
    print("El archivo no existe.")
else:
    if file_path.lower().endswith('.ckpt'):
        print(f'Loading {file_path}..si notas que comenzo el consumo ram, entonces el codigo esta andando, el proceso demora 2 minutos sin gpu y 1 minuto en entornos con gpu (Test: 4 - Gigas)')
        fn = f"{file_path.replace('.ckpt', '')}.safetensors"

        if os.path.isfile(fn):
            print(f'Skipping, as {fn} already exists.')
        else:
            try:
                with torch.no_grad():
                    weights = torch.load(file_path, map_location=device)
                    if 'state_dict' in weights:
                        weights.pop("state_dict")
                    print(f'Saving {fn}...')
                    save_file(weights, fn)
            except Exception as ex:
                print(f'ERROR converting {file_path}: {ex}')
    else:
        print("El archivo no es un archivo .ckpt.")

print('Terminaste')

2

Comments