# quantization / lm-quant-toolkit / data-vis / plot-quant-grids.R
# Author: chen459664
# Commit 998922f (verified): "Add files using upload-large-folder tool"
# Load required libraries
library(ggplot2)
library(dplyr)
library(safetensors)
library(jsonlite)
library(pracma)
library(patchwork)
# Load a single named tensor from a sharded safetensors checkpoint.
#
# The shard index file (`index_json`) is read from `base_dir`; its
# `weight_map` entry maps each tensor name to the shard file that stores it.
# Only that one shard file is loaded.
#
# matrix_name: full tensor name, e.g. "model.layers.0.self_attn.o_proj.weight".
# base_dir:    directory holding the index file and the shard files.
# index_json:  file name of the shard index inside `base_dir`.
#
# Returns the entry for `matrix_name` from safetensors::safe_load_file()
# (its object type follows that function's defaults).
# Stops with an informative error if the name is not listed in the index,
# instead of silently returning NULL.
get_tensor <- function(
    matrix_name,
    base_dir,
    index_json = "model.safetensors.index.json") {
  index_file <- file.path(base_dir, index_json)
  weight_map <- fromJSON(index_file)$weight_map

  # Plain list-membership test; also copes with an index lacking weight_map.
  if (!(matrix_name %in% names(weight_map))) {
    stop(
      "Tensor '", matrix_name, "' not found in ", index_file,
      call. = FALSE
    )
  }

  st_file_fp <- file.path(base_dir, weight_map[[matrix_name]])
  tensors <- safe_load_file(st_file_fp)
  tensors[[matrix_name]]
}
# Compute 1-based row/column index bounds for a square tile of a matrix.
#
# cx, cy:   0-based offsets of the tile's first row / first column.
# bs:       tile side length (rows and columns).
# upper_x,
# upper_y:  number of rows / columns in the matrix; the tile is clipped so it
#           never runs past them (defaults correspond to a 4096 x 4096 matrix).
#
# Returns list(sxs, sxe, sys, sye) such that m[sxs:sxe, sys:sye] selects the
# tile, which is at most bs x bs.
#
# Offsets are shifted by one because R indexing is 1-based. Using cx:(cx + bs)
# directly would select bs + 1 indices for cx > 0, and would only give bs
# indices for cx = 0 because R silently drops index 0.
get_region <- function(cx, cy, bs, upper_x = 4096, upper_y = 4096) {
  stopifnot(
    cx >= 0, cy >= 0, bs >= 1,
    cx < upper_x, cy < upper_y
  )
  list(
    sxs = cx + 1,
    sxe = min(cx + bs, upper_x),
    sys = cy + 1,
    sye = min(cy + bs, upper_y)
  )
}
# Which weight matrix to visualise: a module name inside model.layers.*.
matrix <- "31.self_attn.o_proj"
orig_matrix <- paste0("model.layers.", matrix, ".weight")

# Local Hugging Face hub cache snapshot of Llama-2-7b (machine-specific path).
base_dir <- path.expand(
  "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"
)

# Fetch the full weight matrix and convert it to a base R matrix.
wo <- as.matrix(get_tensor(orig_matrix, base_dir))

# Slice a tile of side `bs` at offset (cx, cy); the exact index convention
# is defined by get_region().
bs <- 16
cx <- 0
cy <- 0
ret <- get_region(cx, cy, bs)
wo1 <- wo[seq(ret$sxs, ret$sxe), seq(ret$sys, ret$sye)]
# Flatten the weight tile into a one-column data frame of values.
raw_data <- as.vector(wo1)
data <- data.frame(
  x = raw_data
)

# Fix the RNG so the k-means random restarts and the y-jitter below (and
# hence cluster colours and point positions in the saved figure) are
# reproducible across runs.
set.seed(42)

# 1-D k-means with 16 centres, keeping the best of 25 random starts.
kmeans_result <- kmeans(data, centers = 16, nstart = 25)
data$cluster <- as.factor(kmeans_result$cluster)

# One row per centroid, placed on the y = 0 axis of the 1-D plot.
centroids <- data.frame(
  x = kmeans_result$centers[, 1],
  y = 0
)

# Random vertical jitter so overlapping 1-D points remain visible.
data$y <- jitter(rep(0, nrow(data)), amount = 0.3)
# Build one row of the comparison figure.
#
# points:      data frame with columns x (weight value), y (jittered
#              position) and cluster (factor used for point colour).
# grid:        data frame with columns x (grid level) and y (0); each level
#              is drawn as a dashed vertical line.
# label:       x-axis title, used as the row label.
# mark_levels: if TRUE, also mark each grid level with a black triangle.
#
# Plot titles are intentionally omitted; the x-axis title names the row.
plot_quant_grid <- function(points, grid, label, mark_levels = FALSE) {
  p <- ggplot() +
    geom_point(
      data = points,
      aes(x = x, y = y, color = cluster),
      alpha = 0.6,
      size = 3
    )
  if (mark_levels) {
    p <- p +
      geom_point(
        data = grid,
        aes(x = x, y = y),
        color = "black",
        size = 3,
        shape = 2
      )
  }
  p +
    geom_segment(
      data = grid,
      aes(x = x, xend = x, y = -0.5, yend = 0.5),
      color = "black",
      linetype = "dashed"
    ) +
    theme_minimal(base_size = 12) +
    labs(x = label, y = "") +
    theme(
      legend.position = "none",
      axis.text.x = element_text(size = 12),
      axis.text.y = element_blank(),
      axis.ticks.y = element_blank()
    ) +
    scale_color_discrete(name = "Cluster")
}

# Use as many RTN / quantile levels as there are k-means centroids, so all
# three rows compare grids of the same size.
n_levels <- nrow(centroids)

# RTN: levels evenly spaced between the tile's minimum and maximum value.
rtn_grid <- data.frame(
  x = linspace(min(data$x), max(data$x), n_levels),
  y = 0
)

# Quantile: levels at evenly spaced empirical quantiles (0% ... 100%).
quantile_grid <- data.frame(
  x = quantile(raw_data, probs = seq(0, 1, length.out = n_levels)),
  y = 0
)

p_kmeans <- plot_quant_grid(data, centroids, "K-means", mark_levels = TRUE)
p_rtn <- plot_quant_grid(data, rtn_grid, "RTN")
p_quantile <- plot_quant_grid(data, quantile_grid, "Quantile")
# Stack the three rows: RTN on top, then k-means, then quantile.
final_plot <- p_rtn / p_kmeans / p_quantile

# Display only in interactive sessions. Under Rscript, auto-printing a bare
# `final_plot` would open the default device and write a stray Rplots.pdf.
if (interactive()) {
  print(final_plot)
}

# ggsave() does not create missing directories by default, so ensure the
# output folder exists first.
dir.create("pdfs", showWarnings = FALSE, recursive = TRUE)
ggsave(
  "pdfs/quant-grid-comparison.pdf",
  plot = final_plot,
  width = 9,
  height = 6
)