def bin_to_safetensors()

in Stable-Diffusion-Vertex/hpo/diffusers/train_diffusers.py [0:0]


# Ordered (pattern, replacement) regex rewrites applied to every checkpoint
# key.  Patterns are raw strings so that "\." and "\d" reach the regex engine
# unchanged (non-raw "\." is an invalid string escape), and every dot that is
# meant literally is escaped so it cannot match an arbitrary character.
_LORA_KEY_REWRITES = (
    (r'\.processor\.', '_'),
    (r'mid_block\.', 'mid_block_'),
    (r'_lora\.up\.', '.lora_up.'),
    (r'_lora\.down\.', '.lora_down.'),
    # A purely numeric path component ".<digits>." becomes "_<digits>_".
    (r'\.(\d+)\.', r'_\1_'),
    (r'to_out', 'to_out_0'),
)


def _rename_lora_key(key):
    """Return *key* with all rewrites applied in order and 'lora_unet_' prepended."""
    for pattern, replacement in _LORA_KEY_REWRITES:
        key = re.sub(pattern, replacement, key)
    return 'lora_unet_' + key


def bin_to_safetensors(output_path):
    """Re-save ``pytorch_lora_weights.bin`` as a ``.safetensors`` file with renamed keys.

    Loads ``<output_path>/pytorch_lora_weights.bin`` with ``torch.load``,
    renames every key (see ``_LORA_KEY_REWRITES``; each new key is prefixed
    with ``lora_unet_``) and writes the result with ``save_file`` to
    ``<output_path>/pytorch_lora_weights.safetensors``.  The values are passed
    through untouched.

    Args:
        output_path: Directory that contains the ``.bin`` file; the
            ``.safetensors`` file is written into the same directory.

    Returns:
        None.  The converted file on disk is the only result.
    """
    # NOTE(review): torch.load unpickles the file; this presumably only ever
    # reads a checkpoint written by the same training run -- confirm before
    # pointing it at files from an untrusted source.
    checkpoint = torch.load(f'{output_path}/pytorch_lora_weights.bin')

    # NOTE(review): the target key layout looks like the naming used by
    # kohya/WebUI-style LoRA loaders -- confirm against the consumer.
    renamed = {
        _rename_lora_key(key): tensor for key, tensor in checkpoint.items()
    }

    new_lora_name = 'pytorch_lora_weights.safetensors'
    print("Saving " + new_lora_name)
    save_file(renamed, f'{output_path}/{new_lora_name}')