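"""Convert a LLaVA-style checkpoint into the Hugging Face format defined by
`configuration_llava.py` / `modeling_llava.py`.

Example invocation (a sketch: the script name and all IDs/paths below are
placeholders, substitute your own):

    python convert_llava_to_hf.py \
        --text_model_id <hub id of the text backbone> \
        --vision_model_id <hub id of the vision tower> \
        --output_path ./llava-hf \
        --old_state_dict_path ./model_state_dict.bin \
        --tokens_num 1
"""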
import argparse

import torch
from transformers import (
    AddedToken,
    AutoConfig,
    AutoTokenizer,
)

from configuration_llava import LlavaConfig
from modeling_llava import LlavaForConditionalGeneration


# Rules for renaming original checkpoint keys to their HF equivalents. They are
# applied in insertion order, so the more specific vision-tower and projector
# prefixes must come before the bare "transformer" rule.
KEYS_TO_MODIFY_MAPPING = {
    "transformer.vision_tower.vision_tower": "vision_model",
    "transformer.mm_projector": "multi_modal_projector",
    "transformer": "language_model.transformer",
    "lm_head": "language_model.lm_head",
    "model.model": "language_model.transformer",
    "multi_modal_projector.0": "multi_modal_projector.linear_1",
    "multi_modal_projector.2": "multi_modal_projector.linear_2",
}


def convert_state_dict_to_hf(state_dict):
    """Rename every key in `state_dict` according to KEYS_TO_MODIFY_MAPPING."""
    new_state_dict = {}
    for key, value in state_dict.items():
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)
        new_state_dict[key] = value
    return new_state_dict

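# A minimal sanity check of the renaming (hypothetical key, not taken from a
# real checkpoint); note how two rules fire in sequence, first
# "transformer.mm_projector" and then "multi_modal_projector.0":
#
#     convert_state_dict_to_hf({"transformer.mm_projector.0.weight": w})
#     # -> {"multi_modal_projector.linear_1.weight": w}
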

def convert_llava_llama_to_hf(text_model_id, vision_model_id, projector_tokens_num, output_path, old_state_dict_path):
    # Create all new floating-point tensors in fp16 to match the original
    # checkpoint's dtype.
    torch.set_default_dtype(torch.float16)
    text_config = AutoConfig.from_pretrained(text_model_id, trust_remote_code=True)

    tokenizer = AutoTokenizer.from_pretrained(text_model_id)
    tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

    config = LlavaConfig(
        text_config=text_config,
        vocab_size=51200,
        vision_tower_name=vision_model_id,
        projector_tokens_num=projector_tokens_num,
    )
    config.text_config.vocab_size = config.vocab_size

    # Instantiate the model under the default device context; its randomly
    # initialized weights are discarded when the converted state dict is loaded.
    with torch.device("cuda"):
        model = LlavaForConditionalGeneration(config)

    state_dict = torch.load(old_state_dict_path, map_location="cpu")
    state_dict = convert_state_dict_to_hf(state_dict)
    # assign=True replaces the module's parameters with the loaded tensors
    # instead of copying values into the existing ones.
    model.load_state_dict(state_dict, strict=True, assign=True)

    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text_model_id",
        help="Hub location of the text model",
    )
    parser.add_argument(
        "--vision_model_id",
        help="Hub location of the vision model",
    )
    parser.add_argument(
        "--output_path",
        help="Location of the converted model",
    )
    parser.add_argument(
        "--old_state_dict_path",
        help="Local path to the raw state dict of the original model, e.g. `model_state_dict.bin`",
    )
    parser.add_argument(
        "--tokens_num",
        type=int,
        default=1,
        help="Number of projector tokens, passed to LlavaConfig as `projector_tokens_num`",
    )
    args = parser.parse_args()
    convert_llava_llama_to_hf(
        args.text_model_id,
        args.vision_model_id,
        args.tokens_num,
        args.output_path,
        args.old_state_dict_path,
    )


if __name__ == "__main__":
    main()
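
# After conversion, the checkpoint can be reloaded like any local HF model
# (a sketch; assumes LlavaForConditionalGeneration subclasses PreTrainedModel
# as usual for custom modeling code):
#
#     model = LlavaForConditionalGeneration.from_pretrained(output_path)
#     tokenizer = AutoTokenizer.from_pretrained(output_path)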