Line data Source code
1 : /*
2 : * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3 : *
4 : * This source code is subject to the terms of the BSD 2 Clause License and
5 : * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 : * was not distributed with this source code in the LICENSE file, you can
7 : * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 : * Media Patent License 1.0 was not distributed with this source code in the
9 : * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 : */
11 :
#include <assert.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>

#include "av1/encoder/cost.h"
#include "av1/encoder/palette.h"
17 :
// Squared Euclidean distance between two points of `dim` float components.
static float calc_dist(const float *p1, const float *p2, int dim) {
  float sum_sq = 0;
  while (dim-- > 0) {
    const float delta = *p1++ - *p2++;
    sum_sq += delta * delta;
  }
  return sum_sq;
}
27 :
// Assigns each of the n points in `data` (dim components each) to its nearest
// of the k centroids (k * dim values), using calc_dist() as the metric.
// indices[i] receives the winning centroid index for point i; on ties the
// lowest index wins because only a strictly smaller distance replaces the
// current best.
void av1_calc_indices(const float *data, const float *centroids,
                      uint8_t *indices, int n, int k, int dim) {
  for (int i = 0; i < n; ++i) {
    const float *const point = data + i * dim;
    int best = 0;
    float best_dist = calc_dist(point, centroids, dim);
    for (int c = 1; c < k; ++c) {
      const float d = calc_dist(point, centroids + c * dim, dim);
      if (d < best_dist) {
        best_dist = d;
        best = c;
      }
    }
    indices[i] = (uint8_t)best;
  }
}
44 :
// Linear congruential generator step: advances *state and returns a
// pseudo-random number in the range [0, 32768) taken from bits 16..30 of the
// new state.
static unsigned int lcg_rand16(unsigned int *state) {
  const unsigned int next = (unsigned int)(*state * 1103515245ULL + 12345);
  *state = next;
  return (next >> 16) & 0x7fffu;
}
50 :
51 0 : static void calc_centroids(const float *data, float *centroids,
52 : const uint8_t *indices, int n, int k, int dim) {
53 : int i, j, index;
54 : int count[PALETTE_MAX_SIZE];
55 0 : unsigned int rand_state = (unsigned int)data[0];
56 :
57 0 : assert(n <= 32768);
58 :
59 0 : memset(count, 0, sizeof(count[0]) * k);
60 0 : memset(centroids, 0, sizeof(centroids[0]) * k * dim);
61 :
62 0 : for (i = 0; i < n; ++i) {
63 0 : index = indices[i];
64 0 : assert(index < k);
65 0 : ++count[index];
66 0 : for (j = 0; j < dim; ++j) {
67 0 : centroids[index * dim + j] += data[i * dim + j];
68 : }
69 : }
70 :
71 0 : for (i = 0; i < k; ++i) {
72 0 : if (count[i] == 0) {
73 0 : memcpy(centroids + i * dim, data + (lcg_rand16(&rand_state) % n) * dim,
74 : sizeof(centroids[0]) * dim);
75 : } else {
76 0 : const float norm = 1.0f / count[i];
77 0 : for (j = 0; j < dim; ++j) centroids[i * dim + j] *= norm;
78 : }
79 : }
80 :
81 : // Round to nearest integers.
82 0 : for (i = 0; i < k * dim; ++i) {
83 0 : centroids[i] = roundf(centroids[i]);
84 : }
85 0 : }
86 :
// Sum, over all n points, of the squared distance (calc_dist()) from each
// point to the centroid indices[] assigns it to. The parameter k is not used.
static float calc_total_dist(const float *data, const float *centroids,
                             const uint8_t *indices, int n, int k, int dim) {
  (void)k;
  float total = 0;
  for (int p = 0; p < n; ++p) {
    const float *const point = data + p * dim;
    const float *const center = centroids + indices[p] * dim;
    total += calc_dist(point, center, dim);
  }
  return total;
}
98 :
99 0 : void av1_k_means(const float *data, float *centroids, uint8_t *indices, int n,
100 : int k, int dim, int max_itr) {
101 : int i;
102 : float this_dist;
103 : float pre_centroids[2 * PALETTE_MAX_SIZE];
104 : uint8_t pre_indices[MAX_SB_SQUARE];
105 :
106 0 : av1_calc_indices(data, centroids, indices, n, k, dim);
107 0 : this_dist = calc_total_dist(data, centroids, indices, n, k, dim);
108 :
109 0 : for (i = 0; i < max_itr; ++i) {
110 0 : const float pre_dist = this_dist;
111 0 : memcpy(pre_centroids, centroids, sizeof(pre_centroids[0]) * k * dim);
112 0 : memcpy(pre_indices, indices, sizeof(pre_indices[0]) * n);
113 :
114 0 : calc_centroids(data, centroids, indices, n, k, dim);
115 0 : av1_calc_indices(data, centroids, indices, n, k, dim);
116 0 : this_dist = calc_total_dist(data, centroids, indices, n, k, dim);
117 :
118 0 : if (this_dist > pre_dist) {
119 0 : memcpy(centroids, pre_centroids, sizeof(pre_centroids[0]) * k * dim);
120 0 : memcpy(indices, pre_indices, sizeof(pre_indices[0]) * n);
121 0 : break;
122 : }
123 0 : if (!memcmp(centroids, pre_centroids, sizeof(pre_centroids[0]) * k * dim))
124 0 : break;
125 : }
126 0 : }
127 :
// qsort() comparator ordering floats ascending. Returns -1, 0 or 1.
// Unordered pairs (NaN) compare as equal.
static int float_comparer(const void *a, const void *b) {
  const float lhs = *(const float *)a;
  const float rhs = *(const float *)b;
  if (lhs < rhs) return -1;
  if (lhs > rhs) return 1;
  return 0;
}
133 :
134 0 : int av1_remove_duplicates(float *centroids, int num_centroids) {
135 : int num_unique; // number of unique centroids
136 : int i;
137 0 : qsort(centroids, num_centroids, sizeof(*centroids), float_comparer);
138 : // Remove duplicates.
139 0 : num_unique = 1;
140 0 : for (i = 1; i < num_centroids; ++i) {
141 0 : if (centroids[i] != centroids[i - 1]) { // found a new unique centroid
142 0 : centroids[num_unique++] = centroids[i];
143 : }
144 : }
145 0 : return num_unique;
146 : }
147 :
// Returns the number of distinct 8-bit values in the rows x cols block at
// `src`, whose rows are `stride` bytes apart.
int av1_count_colors(const uint8_t *src, int stride, int rows, int cols) {
  unsigned char seen[256] = { 0 };

  for (int r = 0; r < rows; ++r) {
    const uint8_t *const line = src + r * stride;
    for (int c = 0; c < cols; ++c) seen[line[c]] = 1;
  }

  int distinct = 0;
  for (int v = 0; v < 256; ++v) distinct += seen[v];
  return distinct;
}
168 :
169 : #if CONFIG_PALETTE_DELTA_ENCODING
// Estimated bit count for coding `num` non-decreasing colors:
//   - the first color costs bit_depth raw bits;
//   - if there is more than one color, 2 extra bits are added (presumably a
//     delta-width field; confirm against the bitstream writer);
//   - each following color is coded as a delta from its predecessor.
// Each difference must be at least min_val. The delta width starts at the
// width of the largest delta (never below bit_depth - 3) and shrinks as the
// remaining value range narrows.
static int delta_encode_cost(const int *colors, int num, int bit_depth,
                             int min_val) {
  if (num <= 0) return 0;
  if (num == 1) return bit_depth;

  // The largest gap between consecutive colors sets the initial delta width.
  int max_delta = 0;
  for (int i = 1; i < num; ++i) {
    const int gap = colors[i] - colors[i - 1];
    assert(gap >= min_val);
    if (gap > max_delta) max_delta = gap;
  }
  const int min_bits = bit_depth - 3;
  int width = AOMMAX(av1_ceil_log2(max_delta + 1 - min_val), min_bits);
  assert(width <= bit_depth);

  int cost = bit_depth + 2;
  int range = (1 << bit_depth) - colors[0] - min_val;
  for (int i = 1; i < num; ++i) {
    cost += width;
    range -= colors[i] - colors[i - 1];
    width = AOMMIN(width, av1_ceil_log2(range));
  }
  return cost;
}
195 :
196 : int av1_index_color_cache(const uint16_t *color_cache, int n_cache,
197 : const uint16_t *colors, int n_colors,
198 : uint8_t *cache_color_found, int *out_cache_colors) {
199 : if (n_cache <= 0) {
200 : for (int i = 0; i < n_colors; ++i) out_cache_colors[i] = colors[i];
201 : return n_colors;
202 : }
203 : memset(cache_color_found, 0, n_cache * sizeof(*cache_color_found));
204 : int n_in_cache = 0;
205 : int in_cache_flags[PALETTE_MAX_SIZE];
206 : memset(in_cache_flags, 0, sizeof(in_cache_flags));
207 : for (int i = 0; i < n_cache && n_in_cache < n_colors; ++i) {
208 : for (int j = 0; j < n_colors; ++j) {
209 : if (colors[j] == color_cache[i]) {
210 : in_cache_flags[j] = 1;
211 : cache_color_found[i] = 1;
212 : ++n_in_cache;
213 : break;
214 : }
215 : }
216 : }
217 : int j = 0;
218 : for (int i = 0; i < n_colors; ++i)
219 : if (!in_cache_flags[i]) out_cache_colors[j++] = colors[i];
220 : assert(j == n_colors - n_in_cache);
221 : return j;
222 : }
223 :
224 : int av1_get_palette_delta_bits_v(const PALETTE_MODE_INFO *const pmi,
225 : int bit_depth, int *zero_count,
226 : int *min_bits) {
227 : const int n = pmi->palette_size[1];
228 : const int max_val = 1 << bit_depth;
229 : int max_d = 0;
230 : *min_bits = bit_depth - 4;
231 : *zero_count = 0;
232 : for (int i = 1; i < n; ++i) {
233 : const int delta = pmi->palette_colors[2 * PALETTE_MAX_SIZE + i] -
234 : pmi->palette_colors[2 * PALETTE_MAX_SIZE + i - 1];
235 : const int v = abs(delta);
236 : const int d = AOMMIN(v, max_val - v);
237 : if (d > max_d) max_d = d;
238 : if (d == 0) ++(*zero_count);
239 : }
240 : return AOMMAX(av1_ceil_log2(max_d + 1), *min_bits);
241 : }
242 : #endif // CONFIG_PALETTE_DELTA_ENCODING
243 :
244 0 : int av1_palette_color_cost_y(const PALETTE_MODE_INFO *const pmi,
245 : #if CONFIG_PALETTE_DELTA_ENCODING
246 : uint16_t *color_cache, int n_cache,
247 : #endif // CONFIG_PALETTE_DELTA_ENCODING
248 : int bit_depth) {
249 0 : const int n = pmi->palette_size[0];
250 : #if CONFIG_PALETTE_DELTA_ENCODING
251 : int out_cache_colors[PALETTE_MAX_SIZE];
252 : uint8_t cache_color_found[2 * PALETTE_MAX_SIZE];
253 : const int n_out_cache =
254 : av1_index_color_cache(color_cache, n_cache, pmi->palette_colors, n,
255 : cache_color_found, out_cache_colors);
256 : const int total_bits =
257 : n_cache + delta_encode_cost(out_cache_colors, n_out_cache, bit_depth, 1);
258 : return total_bits * av1_cost_bit(128, 0);
259 : #else
260 0 : return bit_depth * n * av1_cost_bit(128, 0);
261 : #endif // CONFIG_PALETTE_DELTA_ENCODING
262 : }
263 :
264 0 : int av1_palette_color_cost_uv(const PALETTE_MODE_INFO *const pmi,
265 : #if CONFIG_PALETTE_DELTA_ENCODING
266 : uint16_t *color_cache, int n_cache,
267 : #endif // CONFIG_PALETTE_DELTA_ENCODING
268 : int bit_depth) {
269 0 : const int n = pmi->palette_size[1];
270 : #if CONFIG_PALETTE_DELTA_ENCODING
271 : int total_bits = 0;
272 : // U channel palette color cost.
273 : int out_cache_colors[PALETTE_MAX_SIZE];
274 : uint8_t cache_color_found[2 * PALETTE_MAX_SIZE];
275 : const int n_out_cache = av1_index_color_cache(
276 : color_cache, n_cache, pmi->palette_colors + PALETTE_MAX_SIZE, n,
277 : cache_color_found, out_cache_colors);
278 : total_bits +=
279 : n_cache + delta_encode_cost(out_cache_colors, n_out_cache, bit_depth, 0);
280 :
281 : // V channel palette color cost.
282 : int zero_count = 0, min_bits_v = 0;
283 : const int bits_v =
284 : av1_get_palette_delta_bits_v(pmi, bit_depth, &zero_count, &min_bits_v);
285 : const int bits_using_delta =
286 : 2 + bit_depth + (bits_v + 1) * (n - 1) - zero_count;
287 : const int bits_using_raw = bit_depth * n;
288 : total_bits += 1 + AOMMIN(bits_using_delta, bits_using_raw);
289 : return total_bits * av1_cost_bit(128, 0);
290 : #else
291 0 : return 2 * bit_depth * n * av1_cost_bit(128, 0);
292 : #endif // CONFIG_PALETTE_DELTA_ENCODING
293 : }
294 :
295 : #if CONFIG_HIGHBITDEPTH
// Returns the number of distinct sample values in the rows x cols block of
// high-bitdepth samples behind `src8` (accessed through CONVERT_TO_SHORTPTR).
// Rows are `stride` samples apart.
// Only values below (1 << bit_depth) are counted. bit_depth must be at most
// 12; sample values are assumed to be below 4096, the histogram size.
int av1_count_colors_highbd(const uint8_t *src8, int stride, int rows, int cols,
                            int bit_depth) {
  const uint16_t *const src = CONVERT_TO_SHORTPTR(src8);
  int histogram[1 << 12];

  assert(bit_depth <= 12);
  memset(histogram, 0, sizeof(histogram));

  for (int r = 0; r < rows; ++r) {
    const uint16_t *const line = src + r * stride;
    for (int c = 0; c < cols; ++c) ++histogram[line[c]];
  }

  int distinct = 0;
  for (int v = 0; v < (1 << bit_depth); ++v) distinct += (histogram[v] != 0);
  return distinct;
}
320 : #endif // CONFIG_HIGHBITDEPTH
|