from typing import Optional

import torch
import torch.nn as nn

def find_nearest(x, y):
    """
    Args:
        x (tensor): size (N, D)
        y (tensor): size (M, D)
    """
    dist_mat = torch.cdist(x, y)
    min_dist, indices = torch.min(dist_mat, dim=-1)
    return min_dist, indices
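
# Illustrative shapes (chosen for illustration, not taken from the original file):
# for queries x of shape (4, 2) and a codebook y of shape (16, 2),
# min_dist, indices = find_nearest(x, y) gives min_dist and indices of shape (4,),
# where y[indices[i]] is the code closest to x[i]. The Quantizer below applies
# this lookup independently per codebook group via torch.vmap.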

class Quantizer(nn.Module):
    def __init__(self, TYPE: str = "vq", 
                 code_dim: int = 4, num_code: int = 1024,
                 num_group: Optional[int] = None, tokens_per_data: int = 4,
                 auto_reset: bool = True, reset_ratio: float = 0.2,
                 reset_noise: float = 0.1
        ):
        super().__init__()
        self.code_dim = code_dim
        self.num_code = num_code
        self.tokens_per_data = tokens_per_data
        self.num_group = num_group if num_group is not None else tokens_per_data

        self.auto_reset = auto_reset
        self.reset_ratio = reset_ratio
        self.reset_size = max(int(reset_ratio * num_code), 1)
        self.reset_noise = reset_noise

        assert self.tokens_per_data % self.num_group == 0, "tokens_per_data must be divisible by num_group"

        # nearest-code lookup applied independently to each group's codebook
        self.find_nearest = torch.vmap(find_nearest, in_dims=0, out_dims=0, chunk_size=2)
        # offsets used to index the flattened (num_group * num_code, code_dim) codebook in forward()
        index_offset = torch.tensor(
            [i * self.num_code for i in range(self.num_group)]
        ).long()
        self.register_buffer("index_offset", index_offset)
        # per-group usage counts, accumulated in forward() and consumed by reset_code()
        self.register_buffer("count", torch.zeros(self.num_group, self.num_code))
        # learnable codebooks: one (num_code, code_dim) table per group
        self.embeddings = nn.Parameter(
            torch.randn(self.num_group, self.num_code, self.code_dim)
        )

    def prepare_codebook(self, data, method: str = "random"):
        """
        Args:
            data (tensor): size (N, TPD, D)
        """
        # sample num-code samples from the data
        N, TPD, D = data.size()
        assert TPD == self.tokens_per_data and D == self.code_dim, \
            f"input size {data.size()} does not match the expected size {(N, self.tokens_per_data, self.code_dim)}"
        assert N >= self.num_code, f"num_code {self.num_code} is larger than the data size {N}"

        for i in range(self.num_group):
            if method == "random":
                # sample num_code samples from the data
                indices = torch.randint(0, N, (self.num_code,))
                # sample the codebook from the data
                self.embeddings.data[i] = data[indices][:, i].clone().detach()
            elif method == "kmeans":
                # use k-means (sklearn) cluster centers of the i-th token as the codebook
                from sklearn.cluster import KMeans
                kmeans = KMeans(n_clusters=self.num_code, n_init=10, max_iter=100)
                kmeans.fit(data[:, i].reshape(-1, self.code_dim).cpu().numpy())
                self.embeddings.data[i] = torch.tensor(
                    kmeans.cluster_centers_, dtype=self.embeddings.dtype, device=data.device
                )
            else:
                raise ValueError(f"unknown codebook initialization method: {method}")
        # mark every code as used once so freshly initialized codes are not reset immediately
        nn.init.ones_(self.count)

    def forward(self, x):
        """
        Args:
            x (tensor): size (BS, TPD, D)
        Returns:
            x_quant: size (BS, TPD, D)
            indices: size (BS, TPD)
        """
        BS, TPD, D = x.size()
        assert TPD == self.tokens_per_data and D == self.code_dim, \
            f"input size {x.size()} does not match the expected size {(BS, self.tokens_per_data, self.code_dim)}"

        # compute the indices; below, nS = TPD // num_group is the number of tokens per group
        x = x.view(-1, self.num_group, self.code_dim) # (BS*nS, nG, D)
        x = x.permute(1, 0, 2) # (nG, BS*nS, D)
        with torch.no_grad():
            # find the nearest code in each group's codebook
            dist_min, indices = self.find_nearest(x, self.embeddings)
            
            # update count buffer
            self.count.scatter_add_(dim=-1, index=indices, 
                                    src=torch.ones_like(indices).float())

        # gather the quantized embeddings from the flattened codebook
        indices = indices.permute(1, 0) # (BS*nS, nG)
        indices_offset = indices + self.index_offset
        indices_offset = indices_offset.flatten()
        x_quant = self.embeddings.view(-1, self.code_dim)[indices_offset] # (BS*nS*nG, D)
        x_quant = x_quant.reshape(BS, self.tokens_per_data, self.code_dim)
        indices = indices.reshape(BS, -1)

        return dict(
            x_quant=x_quant,
            indices=indices
        )
    
    def reset_code(self):
        """
        Update the codebook to update the un-activated codebook
        """
        for i in range(self.num_group):
            # sort the codebook via the count numbers
            _, sorted_idx = torch.sort(self.count[i], descending=True)
            reset_size = int(min(self.reset_size, torch.sum(self.count[i] < 0.5)))
            if reset_size > 0.5:
                largest_topk = sorted_idx[:reset_size]
                smallest_topk = sorted_idx[-reset_size:]

                # move the un-activated codebook to the most frequent codebook
                self.embeddings.data[i][smallest_topk] = (
                    self.embeddings.data[i][largest_topk] + 
                    self.reset_noise * torch.randn_like(self.embeddings.data[i][largest_topk]) 
                )

    def reset_count(self):
        # reset the count buffer
        self.count.zero_()
    
    def reset(self):
        if self.auto_reset:
            self.reset_code()
            self.reset_count()
        
    def get_usage(self):
        # fraction of codes used at least once since the last count reset, averaged over groups
        usage = (self.count > 0.5).sum(dim=-1).float() / self.num_code
        return usage.mean().item()
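
# Minimal usage sketch (illustrative; batch sizes and hyperparameters below are
# assumptions for demonstration, not values from the original training setup).
if __name__ == "__main__":
    torch.manual_seed(0)

    # 4 tokens per data point, 4-dim codes, one 64-entry codebook per group
    quantizer = Quantizer(code_dim=4, num_code=64, tokens_per_data=4)

    # initialize the codebooks from a batch of (N, TPD, D) features
    init_data = torch.randn(256, 4, 4)
    quantizer.prepare_codebook(init_data, method="random")

    # quantize a batch and inspect the outputs
    x = torch.randn(32, 4, 4)
    out = quantizer(x)
    print(out["x_quant"].shape)   # torch.Size([32, 4, 4])
    print(out["indices"].shape)   # torch.Size([32, 4])
    print(f"codebook usage: {quantizer.get_usage():.3f}")

    # periodically re-initialize unused codes and clear the usage counts
    quantizer.reset()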