| |
| library(ggplot2) |
| library(dplyr) |
| library(safetensors) |
| library(jsonlite) |
| library(pracma) |
| library(patchwork) |
|
|
#' Load one named tensor from a sharded safetensors checkpoint.
#'
#' Reads the checkpoint index JSON, whose `weight_map` maps tensor names to
#' shard file names. It then loads the shard that holds `matrix_name` and
#' returns that tensor.
#'
#' @param matrix_name Full tensor name as listed in the index's `weight_map`,
#'   e.g. "model.layers.31.self_attn.o_proj.weight".
#' @param base_dir Directory containing the index file and the shard files.
#' @param index_json File name of the index, relative to `base_dir`.
#' @return The tensor element of `safe_load_file()`'s result. Returns `NULL`,
#'   with a warning, when `matrix_name` is not listed in the index.
get_tensor <- function(
    matrix_name,
    base_dir,
    index_json = "model.safetensors.index.json") {
  index_file <- file.path(base_dir, index_json)
  # fromJSON() treats a string that is not an existing file as JSON text.
  # A missing index would then surface as a confusing parse error, so check
  # for the file explicitly.
  if (!file.exists(index_file)) {
    stop("safetensors index not found: ", index_file, call. = FALSE)
  }
  weight_map <- fromJSON(index_file)$weight_map

  if (!(matrix_name %in% names(weight_map))) {
    # Keep the original NULL return value, but do not fail silently:
    # downstream code such as as.matrix() would otherwise error far
    # from the real cause.
    warning(
      "tensor '", matrix_name, "' is not listed in ", index_file,
      call. = FALSE
    )
    return(NULL)
  }

  shard_path <- file.path(base_dir, weight_map[[matrix_name]])
  # NOTE(review): safe_load_file() presumably reads every tensor in the shard
  # even though only one is kept; confirm whether a lazy per-tensor reader is
  # available if memory becomes an issue.
  tensors <- safe_load_file(shard_path)
  tensors[[matrix_name]]
}
|
|
#' Compute 1-based, inclusive row/column bounds of a square tile.
#'
#' `cx`/`cy` are 0-based offsets of the tile's top-left corner. The tile
#' covers rows `(cx + 1):(cx + bs)` and columns `(cy + 1):(cy + bs)`, clipped
#' to `upper_x`/`upper_y`. The bounds can be used directly as
#' `m[sxs:sxe, sys:sye]`. Every unclipped tile has exactly `bs` rows and `bs`
#' columns, and adjacent tiles (offsets differing by `bs`) do not overlap.
#'
#' @param cx,cy 0-based row / column offset of the tile.
#' @param bs Tile edge length; must be positive.
#' @param upper_x,upper_y Number of rows / columns of the target matrix; the
#'   end bounds are clipped to these.
#' @return A list with row bounds `sxs`, `sxe` and column bounds `sys`, `sye`.
get_region <- function(cx, cy, bs, upper_x = 4096, upper_y = 4096) {
  # An offset at or past the bound would produce a descending sxs:sxe
  # sequence and index outside the matrix, so reject it up front.
  stopifnot(
    bs > 0,
    cx >= 0, cx < upper_x,
    cy >= 0, cy < upper_y
  )
  # R indexing is 1-based. Starting the range at `cx` itself would select
  # bs + 1 rows whenever cx > 0; at cx == 0 it would only work because
  # index 0 is silently dropped.
  list(
    sxs = cx + 1,
    sxe = min(cx + bs, upper_x),
    sys = cy + 1,
    sye = min(cy + bs, upper_y)
  )
}
|
|
|
|
# Tensor to inspect: the `self_attn.o_proj` weight of layer 31.
matrix <- "31.self_attn.o_proj"
orig_matrix <- paste0("model.layers.", matrix, ".weight")

# Snapshot directory of Llama-2-7b-hf in the local Hugging Face hub cache.
base_dir <- path.expand(
  "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"
)

# Fetch the tensor and convert it to a base R matrix in one step.
wo <- as.matrix(get_tensor(orig_matrix, base_dir))
|
|
|
|
# Take one bs x bs tile of the weight matrix at offset (cx, cy).
# The end bounds are clipped to the real shape of `wo` instead of
# get_region()'s hard-coded 4096 x 4096 defaults, so other projections
# (different sizes, non-square) work too.
bs <- 16
cx <- 0
cy <- 0
ret <- get_region(cx, cy, bs, upper_x = nrow(wo), upper_y = ncol(wo))
# drop = FALSE keeps a matrix even if a clipped tile has a single row/column.
wo1 <- wo[ret$sxs:ret$sxe, ret$sys:ret$sye, drop = FALSE]
|
|
| |
| |
# Both kmeans() (random initial centres) and jitter() draw random numbers.
# Fix the seed so the saved figure is reproducible across runs.
set.seed(42)

# Flatten the tile into one column of scalar weights.
raw_data <- as.vector(wo1)
data <- data.frame(
  x = raw_data
)

# 1-D k-means with 16 centres, i.e. as many levels as the 16-point grids
# plotted for comparison. nstart = 25 restarts lowers the chance of settling
# on a poor local optimum.
kmeans_result <- kmeans(data, centers = 16, nstart = 25)

# Cluster id of each weight, as a factor so ggplot uses a discrete palette.
data$cluster <- as.factor(kmeans_result$cluster)

# Cluster centres, placed on the y = 0 baseline.
centroids <- data.frame(
  x = kmeans_result$centers[, 1],
  y = 0
)

# Random vertical offset in [-0.3, 0.3] purely to spread the 1-D points so
# overlapping weights remain visible; y carries no meaning.
data$y <- jitter(rep(0, nrow(data)), amount = 0.3)
|
|
| |
# Panel: the tile's weights coloured by k-means cluster. Each centroid is
# marked by an open triangle on the baseline and a dashed vertical line.
p_kmeans <- ggplot() +
  geom_point(
    aes(x, y, colour = cluster),
    data = data,
    size = 3,
    alpha = 0.6
  ) +
  geom_point(
    aes(x, y),
    data = centroids,
    shape = 2,
    size = 3,
    colour = "black"
  ) +
  geom_segment(
    aes(x = x, xend = x, y = -0.5, yend = 0.5),
    data = centroids,
    linetype = "dashed",
    colour = "black"
  ) +
  theme_minimal(base_size = 12) +
  xlab("K-means") +
  ylab("") +
  # The y axis is only jitter, so hide its text/ticks; the legend is hidden
  # because cluster colours are for visual grouping only.
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    axis.text.x = element_text(size = 12),
    legend.position = "none"
  ) +
  scale_color_discrete(name = "Cluster")
|
|
# Uniform grid labelled "RTN" (presumably round-to-nearest): 16 evenly
# spaced levels from the smallest to the largest weight in the tile.
rtn_grid <- data.frame(x = with(data, linspace(min(x), max(x), 16)), y = 0)

# Panel: the same weights, still coloured by k-means cluster for comparison,
# with the uniform levels drawn as dashed vertical lines.
p_rtn <- ggplot() +
  geom_point(
    aes(x, y, colour = cluster),
    data = data,
    size = 3,
    alpha = 0.6
  ) +
  geom_segment(
    aes(x = x, xend = x, y = -0.5, yend = 0.5),
    data = rtn_grid,
    linetype = "dashed",
    colour = "black"
  ) +
  theme_minimal(base_size = 12) +
  xlab("RTN") +
  ylab("") +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    axis.text.x = element_text(size = 12),
    legend.position = "none"
  ) +
  scale_color_discrete(name = "Cluster")
|
|
# Quantile grid: 16 empirical quantiles of the tile's weights at evenly
# spaced probabilities from 0 to 1, so the minimum and maximum are included.
quantile_grid <- data.frame(
  x = quantile(raw_data, linspace(0, 100, 16) / 100),
  y = 0
)

# Panel: the same weights against the quantile levels (dashed lines).
p_quantile <- ggplot() +
  geom_point(
    aes(x, y, colour = cluster),
    data = data,
    size = 3,
    alpha = 0.6
  ) +
  geom_segment(
    aes(x = x, xend = x, y = -0.5, yend = 0.5),
    data = quantile_grid,
    linetype = "dashed",
    colour = "black"
  ) +
  theme_minimal(base_size = 12) +
  xlab("Quantile") +
  ylab("") +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    axis.text.x = element_text(size = 12),
    legend.position = "none"
  ) +
  scale_color_discrete(name = "Cluster")
|
|
# Stack the three panels vertically (patchwork `/`): RTN, k-means, quantile.
final_plot <- p_rtn / p_kmeans / p_quantile
final_plot

# ggsave() does not create missing directories by default, so make sure the
# output folder exists before writing the PDF.
dir.create("pdfs", showWarnings = FALSE, recursive = TRUE)
# Size is 9 x 6 in ggsave()'s default units (inches).
ggsave(
  "pdfs/quant-grid-comparison.pdf",
  plot = final_plot,
  width = 9,
  height = 6
)
|
|