# Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file.
# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32
# Get model checkpoints from https://huggingface.co/BlinkDL
# See FILE_FORMAT.md for the documentation on the file format.

import argparse
import struct
import torch
from typing import Dict

def parse_args():
    parser = argparse.ArgumentParser(description='Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file')
    parser.add_argument('src_path', help='Path to PyTorch checkpoint file')
    parser.add_argument('dest_path', help='Path to rwkv.cpp checkpoint file, will be overwritten')
    # nargs='?' makes this positional argument optional, so the declared float32 default can actually apply
    parser.add_argument('data_type', help='Data type, float16 or float32', type=str, choices=['float16', 'float32'], default='float32', nargs='?')
    return parser.parse_args()

def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
    n_layer = 0

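    # Per-layer tensors are stored under keys like 'blocks.{i}.ln1.weight';
    # count layers by probing successive indices until one is missing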
    while f'blocks.{n_layer}.ln1.weight' in state_dict:
        n_layer += 1

    assert n_layer > 0

    return n_layer

def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str) -> None:
    emb_weight: torch.Tensor = state_dict['emb.weight']

    n_layer = get_layer_count(state_dict)
    n_vocab = emb_weight.shape[0]
    n_embed = emb_weight.shape[1]

    with open(dest_path, 'wb') as out_file:
        out_file.write(struct.pack(
            # Disable padding with '='
            '=iiiiii',
            # Magic: 'ggmf' in hex
            0x67676d66,
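            # File format version (see FILE_FORMAT.md)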
            101,
            n_vocab,
            n_embed,
            n_layer,
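            # Data type flag: 0 = float32, 1 = float16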
            1 if data_type == 'float16' else 0
        ))

        for k in state_dict.keys():
            tensor = state_dict[k].float()

            # Same processing as in "RWKV_in_150_lines.py"
            if '.time_' in k:
                # (1, 1, n_embed) -> (n_embed)
                tensor = tensor.squeeze()
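
            # Precompute -exp(time_decay), the same transformation RWKV_in_150_lines.py
            # applies to the decay before using it at inference time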
            if '.time_decay' in k:
                tensor = -torch.exp(tensor)

            # Keep 1-dim vectors in fp32
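            # (they are tiny compared to the matrices, so full precision costs almost nothing)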
            if data_type == 'float16' and len(tensor.shape) > 1:
                tensor = tensor.half()

            shape = tensor.shape

            print(f'Writing {k}, shape {shape}, type {tensor.dtype}')

            k_encoded: bytes = k.encode('utf-8')
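
            # Per-tensor header: dimension count, key length in bytes, data type flag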
            out_file.write(struct.pack(
                '=iii',
                len(shape),
                len(k_encoded),
                1 if tensor.dtype == torch.float16 else 0
            ))

            # Dimension order is reversed here:
            # * PyTorch shape is (x rows, y columns)
            # * ggml shape is (y elements in a row, x elements in a column)
            # Both shapes represent the same tensor.
            for dim in reversed(tensor.shape):
                out_file.write(struct.pack('=i', dim))

            out_file.write(k_encoded)
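
            # Raw tensor data follows the header, row-major, in machine byte order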
            tensor.numpy().tofile(out_file)

def main() -> None:
    args = parse_args()

    print(f'Reading {args.src_path}')
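
    # map_location='cpu' keeps all tensors on the CPU, so no GPU is needed for conversion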
    state_dict: Dict[str, torch.Tensor] = torch.load(args.src_path, map_location='cpu')

    write_state_dict(state_dict, args.dest_path, args.data_type)

    print('Done')

if __name__ == "__main__":
    main()