refactor: reorganize project structure and fix broken references

- Move scripts to scripts/ directory (roda.sh, prepara_db.py, etc.)
- Move shell config to shell/ directory (Caddyfile, auth.py, haloy.yml)
- Move basedosdados.duckdb to data/ directory
- Update Dockerfile and start.sh with new file paths
- Update README.md with correct script paths
- Remove Python ask.py (replaced by Rust binary in ask/ask)
- Add Rust source files (schema_filter.rs, sql_generator.rs, table_selector.rs)
- Remove sentence-transformer dependencies from ask
- Move docs and context artifacts to their directories
This commit is contained in:
2026-03-29 20:46:27 +02:00
parent 02cb13362c
commit ed5fa6756e
43 changed files with 302366 additions and 1093 deletions

View File

@@ -1,4 +1,9 @@
mod schema_filter;
mod sql_generator;
mod table_selector;
use anyhow::{Context, Result};
use chrono::Utc;
use crossterm::{
event::{
DisableBracketedPaste, DisableMouseCapture, EnableBracketedPaste, EnableMouseCapture,
@@ -9,14 +14,12 @@ use crossterm::{
};
use duckdb::Connection;
use ratatui::{
buffer::Buffer,
layout::{Constraint, Direction, Layout, Rect},
style::{Color, Modifier, Style},
text::{Line, Span},
widgets::{Block, Borders, Gauge, Paragraph, Row, Table, TableState, Wrap},
Frame, Terminal,
};
use chrono::Utc;
use serde_json::{json, Value};
use std::{
env, fs,
@@ -43,6 +46,10 @@ struct Config {
schema: String,
db_file: String,
prompt_file: String,
use_table_selection: bool,
embeddings_file: String,
schema_json: String,
similarity_threshold: f32,
}
enum Phase {
@@ -234,10 +241,23 @@ fn spawn_worker(
model: String,
prompt_file: String,
db_file: String,
use_table_selection: bool,
embeddings_file: String,
schema_json: String,
similarity_threshold: f32,
) -> mpsc::Receiver<WorkerMsg> {
let (tx, rx) = mpsc::channel::<WorkerMsg>();
std::thread::spawn(
move || match ask_model(&question, &schema, &model, &prompt_file) {
std::thread::spawn(move || {
match ask_model_with_selection(
&question,
&schema,
&model,
&prompt_file,
use_table_selection,
&embeddings_file,
&schema_json,
similarity_threshold,
) {
Err(e) => {
let err = format!("{:#}", e);
log_question(&question, "", false, Some(&err));
@@ -257,8 +277,8 @@ fn spawn_worker(
}
}
}
},
);
}
});
rx
}
@@ -270,6 +290,10 @@ fn spawn_retry_worker(
model: String,
prompt_file: String,
db_file: String,
use_table_selection: bool,
embeddings_file: String,
schema_json: String,
similarity_threshold: f32,
) -> mpsc::Receiver<WorkerMsg> {
let retry_q = format!(
"{}\n\nO SQL que você gerou falhou com este erro DuckDB:\n```\n{}\n```\n\n\
@@ -277,7 +301,17 @@ fn spawn_retry_worker(
Corrija o SQL. Retorne APENAS o SQL corrigido, sem explicação.",
question, error, failed_sql
);
spawn_worker(retry_q, schema, model, prompt_file, db_file)
spawn_worker(
retry_q,
schema,
model,
prompt_file,
db_file,
use_table_selection,
embeddings_file,
schema_json,
similarity_threshold,
)
}
// ── event handling ────────────────────────────────────────────────────────────
@@ -327,6 +361,10 @@ impl App {
self.config.model.clone(),
self.config.prompt_file.clone(),
self.config.db_file.clone(),
self.config.use_table_selection,
self.config.embeddings_file.clone(),
self.config.schema_json.clone(),
self.config.similarity_threshold,
));
}
@@ -398,6 +436,10 @@ impl App {
self.config.model.clone(),
self.config.prompt_file.clone(),
self.config.db_file.clone(),
self.config.use_table_selection,
self.config.embeddings_file.clone(),
self.config.schema_json.clone(),
self.config.similarity_threshold,
));
self.last_sql.clear();
} else {
@@ -723,7 +765,12 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
let col_max_widths: Vec<usize> = (0..col_count)
.map(|i| {
let header_len = cols[i].len();
let data_len = rows.iter().filter_map(|r| r.get(i)).map(|c| c.len()).max().unwrap_or(0);
let data_len = rows
.iter()
.filter_map(|r| r.get(i))
.map(|c| c.len())
.max()
.unwrap_or(0);
(header_len.max(data_len)).max(min_col_width as usize)
})
.collect();
@@ -732,16 +779,24 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
let use_wrap = total_needed > available_width as usize;
if use_wrap {
let wrap_width = (available_width as usize / col_count).max(min_col_width as usize);
let header_lines: Vec<Line> = cols.iter()
let wrap_width =
(available_width as usize / col_count).max(min_col_width as usize);
let header_lines: Vec<Line> = cols
.iter()
.enumerate()
.map(|(i, c)| {
let wrapped = wrap_text(c, wrap_width);
Line::from(wrapped)
let spans: Vec<Span> =
wrapped.into_iter().map(|s| Span::raw(s)).collect();
Line::from(spans)
})
.collect();
let max_header_lines = header_lines.iter().map(|l| l.len()).max().unwrap_or(1);
let max_header_lines = header_lines
.iter()
.map(|l| l.spans.len())
.max()
.unwrap_or(1);
let mut all_row_lines: Vec<Vec<Line>> = Vec::new();
for row in rows {
@@ -749,19 +804,19 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
.map(|i| {
let cell = row.get(i).map(|s| s.as_str()).unwrap_or("");
let wrapped = wrap_text(cell, wrap_width);
Line::from(wrapped)
let spans: Vec<Span> =
wrapped.into_iter().map(|s| Span::raw(s)).collect();
Line::from(spans)
})
.collect();
let max_lines = row_lines.iter().map(|l| l.len()).max().unwrap_or(1);
let max_lines = row_lines.iter().map(|l| l.spans.len()).max().unwrap_or(1);
all_row_lines.push(row_lines);
}
let selected_idx = table_state.selected().unwrap_or(0);
let table_title = format!(" Resultados ({}/{}) ", selected_idx + 1, n);
let block = Block::default()
.borders(Borders::ALL)
.title(table_title);
let block = Block::default().borders(Borders::ALL).title(table_title);
let area = chunks[2];
f.render_widget(block, area);
@@ -778,29 +833,32 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
let start_row = if n > visible_rows as usize {
let scroll = selected_idx as i32 - visible_rows as i32 / 2;
scroll.max(0) as usize.min(n.saturating_sub(visible_rows as usize))
(scroll.max(0) as usize).min(n.saturating_sub(visible_rows as usize))
} else {
0
};
let header_bg = Style::default().fg(Color::Yellow).add_modifier(Modifier::BOLD);
let header_bg = Style::default()
.fg(Color::Yellow)
.add_modifier(Modifier::BOLD);
for (col_idx, header_line) in header_lines.iter().enumerate() {
let col_x = inner_area.x + (col_idx as u16) * (wrap_width as u16 + 1);
let col_width = wrap_width as u16;
for (line_idx, line) in header_line.iter().enumerate() {
for (line_idx, span) in header_line.spans.iter().enumerate() {
let y = inner_area.y + line_idx as u16;
if y >= inner_area.y + inner_area.height {
break;
}
let spans: Vec<Span> = line.spans.iter().map(|s| {
Span::styled(s.content.clone(), header_bg)
}).collect();
f.render_widget(Paragraph::new(Line::from(spans)), Rect {
x: col_x,
y,
width: col_width,
height: 1,
});
let styled_span = Span::styled(span.content.clone(), header_bg);
f.render_widget(
Paragraph::new(Line::from(styled_span)),
Rect {
x: col_x,
y,
width: col_width,
height: 1,
},
);
}
}
@@ -811,7 +869,9 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
}
let is_selected = row_idx == selected_idx;
let row_style = if is_selected {
Style::default().bg(Color::DarkGray).add_modifier(Modifier::BOLD)
Style::default()
.bg(Color::DarkGray)
.add_modifier(Modifier::BOLD)
} else {
Style::default()
};
@@ -820,20 +880,21 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
for (col_idx, cell_lines) in row_lines.iter().enumerate() {
let col_x = inner_area.x + (col_idx as u16) * (wrap_width as u16 + 1);
let col_width = wrap_width as u16;
for (line_idx, line) in cell_lines.iter().enumerate() {
for (line_idx, span) in cell_lines.spans.iter().enumerate() {
let cell_y = y + line_idx as u16;
if cell_y >= inner_area.y + inner_area.height {
break;
}
let spans: Vec<Span> = line.spans.iter().map(|s| {
Span::styled(s.content.clone(), row_style)
}).collect();
f.render_widget(Paragraph::new(Line::from(spans)), Rect {
x: col_x,
y: cell_y,
width: col_width,
height: 1,
});
let styled_span = Span::styled(span.content.clone(), row_style);
f.render_widget(
Paragraph::new(Line::from(styled_span)),
Rect {
x: col_x,
y: cell_y,
width: col_width,
height: 1,
},
);
}
}
@@ -850,7 +911,8 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
}
}
} else {
let col_widths: Vec<Constraint> = cols.iter()
let col_widths: Vec<Constraint> = cols
.iter()
.enumerate()
.map(|(i, _)| {
let w = col_max_widths[i] as u16;
@@ -1008,6 +1070,55 @@ fn ask_model(question: &str, schema: &str, model: &str, prompt_file: &str) -> Re
Ok(ensure_sql(&sql))
}
/// Generates SQL for `question`, optionally shrinking the schema shown to the
/// model by first selecting relevant tables via embeddings.
///
/// `_full_schema` is accepted for signature compatibility but unused here —
/// the schema text is always rebuilt from `schema_json`. `model` is likewise
/// unread in this body (the backend picks its own model via env vars in
/// `create_sql_generator`) — NOTE(review): confirm that is intentional.
/// When `use_selection` is true, tables are chosen from `embeddings_file`
/// using `similarity_threshold`; any selection failure falls back to the
/// full schema instead of aborting.
fn ask_model_with_selection(
    question: &str,
    _full_schema: &str,
    model: &str,
    prompt_file: &str,
    use_selection: bool,
    embeddings_file: &str,
    schema_json: &str,
    similarity_threshold: f32,
) -> Result<String> {
    let prompt_template = fs::read_to_string(prompt_file)
        .with_context(|| format!("Não foi possível ler o prompt: {}", prompt_file))?;
    // Decide which schema text the model sees: filtered to the selected
    // tables when selection succeeds, full schema otherwise.
    let (schema_to_use, selected_tables) = if use_selection {
        match table_selector::select_tables_from_question(
            question,
            embeddings_file,
            similarity_threshold,
        ) {
            Ok(table_ids) => {
                eprintln!(
                    "=> Selecionadas {} tables relevantes: {:?}",
                    table_ids.len(),
                    table_ids
                );
                let schema_filter = schema_filter::SchemaFilter::new(schema_json)?;
                let filtered_schema = schema_filter.filter_tables(&table_ids);
                (filtered_schema, Some(table_ids))
            }
            // Selection failure is deliberately non-fatal: warn and continue
            // with the complete schema.
            Err(e) => {
                eprintln!(
                    "=> Aviso: falha na seleção de tables ({}), usando schema completo",
                    e
                );
                let schema_filter = schema_filter::SchemaFilter::new(schema_json)?;
                (schema_filter.full_schema_text(), None)
            }
        }
    } else {
        let schema_filter = schema_filter::SchemaFilter::new(schema_json)?;
        (schema_filter.full_schema_text(), None)
    };
    // NOTE(review): `selected_tables` is bound but never used past this
    // point — confirm whether it should be logged or returned.
    let generator = sql_generator::create_sql_generator()?;
    let sql = generator.generate(question, &schema_to_use, &prompt_template)?;
    Ok(ensure_sql(&sql))
}
fn ask_gemini(question: &str, system_prompt: &str, model: &str) -> Result<String> {
let key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY não definida")?;
let url = format!(
@@ -1309,6 +1420,12 @@ VARIÁVEIS DE AMBIENTE
OPENROUTER_API_KEY necessária para modelos OpenRouter
GEMINI_MODEL modelo padrão (sobrescrito por --model)
SCHEMA_FILE DDL do schema [context/schema_compact_inline.txt]
SCHEMA_JSON full schema JSON [context/basedosdados-schema.json]
EMBEDDINGS_FILE table embeddings [context/table_embeddings.json]
TOP_K_TABLES número de tables a selecionar [5]
SQL_GENERATOR sql generator: sqlcoder|gemini|openrouter [gemini]
OLLAMA_MODEL modelo ollama [sqlcoder]
OLLAMA_HOST host ollama [http://localhost:11434]
PROMPT_FILE prompt do sistema [ask/system_prompt.md]
DB_FILE banco DuckDB [basedosdados.duckdb]
"#
@@ -1321,7 +1438,18 @@ VARIÁVEIS DE AMBIENTE
});
let schema_file =
env::var("SCHEMA_FILE").unwrap_or_else(|_| "context/schema_compact_inline.txt".into());
let db_file = env::var("DB_FILE").unwrap_or_else(|_| "basedosdados.duckdb".into());
let schema_json =
env::var("SCHEMA_JSON").unwrap_or_else(|_| "context/basedosdados-schema.json".into());
let embeddings_file =
env::var("EMBEDDINGS_FILE").unwrap_or_else(|_| "context/table_embeddings.json".into());
let similarity_threshold = env::var("SIMILARITY_THRESHOLD")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(0.35);
let use_table_selection = env::var("USE_TABLE_SELECTION")
.map(|v| v != "false" && v != "0")
.unwrap_or(true);
let db_file = env::var("DB_FILE").unwrap_or_else(|_| "data/basedosdados.duckdb".into());
let prompt_file = env::var("PROMPT_FILE").unwrap_or_else(|_| "ask/system_prompt.md".into());
let schema = fs::read_to_string(&schema_file)
.with_context(|| format!("Não foi possível ler o schema: {}", schema_file))?;
@@ -1333,6 +1461,10 @@ VARIÁVEIS DE AMBIENTE
schema,
db_file,
prompt_file,
use_table_selection,
embeddings_file,
schema_json,
similarity_threshold,
});
}
@@ -1341,7 +1473,16 @@ VARIÁVEIS DE AMBIENTE
eprintln!("\nModel: {}\nPergunta: {}\n", model, question);
let t0 = Instant::now();
let sql = ask_model(&question, &schema, &model, &prompt_file)?;
let sql = ask_model_with_selection(
&question,
&schema,
&model,
&prompt_file,
use_table_selection,
&embeddings_file,
&schema_json,
similarity_threshold,
)?;
eprintln!("=> SQL gerado em {}", fmt_duration(t0.elapsed()));
print_sql_box(&sql);

135
ask/src/schema_filter.rs Normal file
View File

@@ -0,0 +1,135 @@
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fs;
use std::path::Path;
/// A single column entry as stored in the schema JSON.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Column {
    pub name: String,
    // The JSON key is "type", which is a reserved word in Rust.
    #[serde(rename = "type")]
    pub col_type: String,
    pub description: Option<String>,
}

/// All columns belonging to one table.
pub type TableColumns = Vec<Column>;

/// The whole schema file: dataset name -> table name -> columns.
#[derive(Debug, Clone, Deserialize)]
pub struct FullSchema {
    // `flatten` maps the top-level JSON object directly onto `datasets`.
    #[serde(flatten)]
    pub datasets:
        std::collections::HashMap<String, std::collections::HashMap<String, TableColumns>>,
}

/// Renders compact schema text (full or filtered to a table subset) for
/// inclusion in LLM prompts.
pub struct SchemaFilter {
    schema: FullSchema,
}
impl SchemaFilter {
pub fn new<P: AsRef<Path>>(schema_path: P) -> anyhow::Result<Self> {
let content = fs::read_to_string(schema_path)?;
let schema: FullSchema = serde_json::from_str(&content)?;
Ok(Self { schema })
}
pub fn filter_tables(&self, table_ids: &[String]) -> String {
let selected: HashSet<String> = table_ids.iter().cloned().collect();
let mut lines = Vec::new();
lines.push("# Base dos Dados — Filtered Schema".to_string());
lines.push(
"# Legend: V=VARCHAR I=INT D=DOUBLE Dt=DATE B=BOOLEAN Dec=DECIMAL Ts=TIMESTAMP Ti=TIME"
.to_string(),
);
lines.push("# Format: dataset.table: col:TYPE description".to_string());
lines.push(String::new());
for (dataset, tables) in &self.schema.datasets {
for (table, columns) in tables {
let full_id = format!("{}.{}", dataset, table);
if selected.contains(&full_id) {
let col_str = columns
.iter()
.map(|c| {
let desc = c.description.as_deref().unwrap_or("");
if desc.is_empty() {
format!("{}:{}", c.name, type_abbrev(&c.col_type))
} else {
format!("{}:{} {}", c.name, type_abbrev(&c.col_type), desc)
}
})
.collect::<Vec<_>>()
.join(" ");
lines.push(format!("{}: {}", full_id, col_str));
}
}
}
lines.join("\n")
}
pub fn full_schema_text(&self) -> String {
let mut lines = Vec::new();
lines.push("# Base dos Dados — Full Schema".to_string());
lines.push(
"# Legend: V=VARCHAR I=INT D=DOUBLE Dt=DATE B=BOOLEAN Dec=DECIMAL Ts=TIMESTAMP Ti=TIME"
.to_string(),
);
lines.push("# Format: dataset.table: col:TYPE description".to_string());
lines.push(String::new());
for (dataset, tables) in &self.schema.datasets {
for (table, columns) in tables {
let full_id = format!("{}.{}", dataset, table);
let col_str = columns
.iter()
.map(|c| {
let desc = c.description.as_deref().unwrap_or("");
if desc.is_empty() {
format!("{}:{}", c.name, type_abbrev(&c.col_type))
} else {
format!("{}:{} {}", c.name, type_abbrev(&c.col_type), desc)
}
})
.collect::<Vec<_>>()
.join(" ");
lines.push(format!("{}: {}", full_id, col_str));
}
}
lines.join("\n")
}
pub fn dataset_count(&self) -> usize {
self.schema.datasets.len()
}
pub fn table_count(&self) -> usize {
self.schema.datasets.values().map(|t| t.len()).sum()
}
}
/// Maps a SQL column type name to the short code used in the schema legend
/// (V, I, D, Dt, Ts, Ti, B, Dec); unrecognized types pass through unchanged.
/// Matching is case-insensitive substring search, so e.g. "BIGINT" -> "I"
/// and "DECIMAL(10,2)" -> "Dec".
fn type_abbrev(full_type: &str) -> String {
    let upper = full_type.to_uppercase();
    // Order matters: TIMESTAMP must be checked before the bare TIME probe,
    // and DATE explicitly excludes TIMESTAMP.
    let code = if upper.contains("VARCHAR") || upper.contains("STRING") {
        "V"
    } else if upper.contains("INT") {
        "I"
    } else if upper.contains("DOUBLE") || upper.contains("FLOAT") {
        "D"
    } else if upper.contains("DATE") && !upper.contains("TIMESTAMP") {
        "Dt"
    } else if upper.contains("TIMESTAMP") {
        "Ts"
    } else if upper.contains("TIME") {
        "Ti"
    } else if upper.contains("BOOLEAN") {
        "B"
    } else if upper.contains("DECIMAL") {
        "Dec"
    } else {
        // Unknown type: keep the original spelling.
        return full_type.to_string();
    };
    code.to_string()
}

207
ask/src/sql_generator.rs Normal file
View File

@@ -0,0 +1,207 @@
use anyhow::{Context, Result};
use serde_json::Value;
use std::env;
/// A backend that turns a natural-language question plus a schema into SQL.
pub trait SqlGenerator: Send + Sync {
    /// Generates SQL for `question` against `schema`, using `prompt_template`
    /// as the system prompt. Implementations return the SQL text only.
    fn generate(&self, question: &str, schema: &str, prompt_template: &str) -> Result<String>;
}
/// Builds the SQL generator backend chosen by the `SQL_GENERATOR` environment
/// variable (`sqlcoder`, `gemini`, or `openrouter`; default `gemini`).
/// Fails when the variable holds an unknown value or the chosen backend's
/// own configuration is missing.
pub fn create_sql_generator() -> Result<Box<dyn SqlGenerator>> {
    let choice = env::var("SQL_GENERATOR").unwrap_or_else(|_| String::from("gemini"));
    let generator: Box<dyn SqlGenerator> = match choice.as_str() {
        "gemini" => Box::new(GeminiGenerator::new()?),
        "sqlcoder" => Box::new(SqlCoderGenerator::new()?),
        "openrouter" => Box::new(OpenRouterGenerator::new()?),
        other => anyhow::bail!(
            "Unknown SQL_GENERATOR: {}. Use: sqlcoder, gemini, or openrouter",
            other
        ),
    };
    Ok(generator)
}
/// Google Gemini backend, configured via `GEMINI_MODEL` / `GEMINI_API_KEY`.
pub struct GeminiGenerator {
    model: String,
    api_key: String,
}

impl GeminiGenerator {
    /// Reads configuration from the environment; errors when
    /// `GEMINI_API_KEY` is unset. The model defaults to
    /// "gemini-flash-latest".
    pub fn new() -> Result<Self> {
        let api_key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY not defined")?;
        let model =
            env::var("GEMINI_MODEL").unwrap_or_else(|_| String::from("gemini-flash-latest"));
        Ok(Self { model, api_key })
    }
}
impl SqlGenerator for GeminiGenerator {
    /// Sends the system prompt (template + schema) and question to the Gemini
    /// REST API and returns the generated SQL with code fences stripped.
    fn generate(&self, question: &str, schema: &str, prompt_template: &str) -> Result<String> {
        let system_prompt = format!("{}\n\nSchema DDL:\n\n{}", prompt_template.trim(), schema);
        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
            self.model
        );
        let payload = serde_json::json!({
            "system_instruction": { "parts": [{ "text": system_prompt }] },
            "contents": [{ "parts": [{ "text": question }] }]
        });
        // Generous timeout: generation can be slow for large schemas.
        let http = reqwest::blocking::Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .build()?;
        let response = http
            .post(&url)
            .header("Content-Type", "application/json")
            .header("X-goog-api-key", &self.api_key)
            .json(&payload)
            .send()
            .context("Gemini HTTP request failed")?;
        // Capture status before consuming the response body.
        let status = response.status();
        let body: Value = response.json().context("Failed to parse Gemini response")?;
        if !status.is_success() {
            anyhow::bail!("Gemini API error {}: {}", status, body);
        }
        let sql = body["candidates"][0]["content"]["parts"][0]["text"]
            .as_str()
            .context("Unexpected Gemini response format")?;
        Ok(strip_fences(sql.trim()))
    }
}
/// OpenRouter backend, configured via `OPENROUTER_MODEL` /
/// `OPENROUTER_API_KEY`.
pub struct OpenRouterGenerator {
    model: String,
    api_key: String,
}

impl OpenRouterGenerator {
    /// Reads configuration from the environment; errors when
    /// `OPENROUTER_API_KEY` is unset. The model defaults to
    /// "openai/gpt-4o-mini".
    pub fn new() -> Result<Self> {
        let api_key = env::var("OPENROUTER_API_KEY").context("OPENROUTER_API_KEY not defined")?;
        let model =
            env::var("OPENROUTER_MODEL").unwrap_or_else(|_| String::from("openai/gpt-4o-mini"));
        Ok(Self { model, api_key })
    }
}
impl SqlGenerator for OpenRouterGenerator {
    /// Sends a chat-completions request to OpenRouter and returns the
    /// generated SQL with code fences stripped.
    fn generate(&self, question: &str, schema: &str, prompt_template: &str) -> Result<String> {
        let system_prompt = format!("{}\n\nSchema DDL:\n\n{}", prompt_template.trim(), schema);
        let payload = serde_json::json!({
            "model": self.model,
            "messages": [
                { "role": "system", "content": system_prompt },
                { "role": "user", "content": question }
            ]
        });
        // Generous timeout: generation can be slow for large schemas.
        let http = reqwest::blocking::Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .build()?;
        let response = http
            .post("https://openrouter.ai/api/v1/chat/completions")
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("HTTP-Referer", "https://basedosdados.org")
            .header("X-Title", "Base dos Dados Ask")
            .json(&payload)
            .send()
            .context("OpenRouter HTTP request failed")?;
        // Capture status before consuming the response body.
        let status = response.status();
        let body: Value = response.json().context("Failed to parse OpenRouter response")?;
        if !status.is_success() {
            anyhow::bail!("OpenRouter API error {}: {}", status, body);
        }
        let sql = body["choices"][0]["message"]["content"]
            .as_str()
            .context("Unexpected OpenRouter response format")?;
        Ok(strip_fences(sql.trim()))
    }
}
/// Local Ollama backend, configured via `OLLAMA_MODEL` / `OLLAMA_HOST`.
pub struct SqlCoderGenerator {
    model: String,
    host: String,
}

impl SqlCoderGenerator {
    /// Reads configuration from the environment; infallible in practice,
    /// returning `Result` only for interface symmetry with the other
    /// backends. Defaults: model "sqlcoder", host "http://localhost:11434".
    pub fn new() -> Result<Self> {
        let host =
            env::var("OLLAMA_HOST").unwrap_or_else(|_| String::from("http://localhost:11434"));
        let model = env::var("OLLAMA_MODEL").unwrap_or_else(|_| String::from("sqlcoder"));
        Ok(Self { model, host })
    }
}
impl SqlGenerator for SqlCoderGenerator {
    /// Sends the full prompt (template + schema + question) to a local
    /// Ollama `/api/generate` endpoint and returns the generated SQL with
    /// code fences stripped.
    fn generate(&self, question: &str, schema: &str, prompt_template: &str) -> Result<String> {
        let full_prompt = format!(
            "{}\n\nSchema DDL:\n\n{}\n\nQuestion: {}\n\nSQL:",
            prompt_template.trim(),
            schema,
            question
        );
        let payload = serde_json::json!({
            "model": self.model,
            "prompt": full_prompt,
            "stream": false
        });
        // Generous timeout: local models can be slow.
        let http = reqwest::blocking::Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .build()?;
        let response = http
            .post(format!("{}/api/generate", self.host))
            .header("Content-Type", "application/json")
            .json(&payload)
            .send()
            .context("Ollama HTTP request failed")?;
        // Capture status before consuming the response body.
        let status = response.status();
        let body: Value = response.json().context("Failed to parse Ollama response")?;
        if !status.is_success() {
            anyhow::bail!("Ollama API error {}: {}", status, body);
        }
        let sql = body["response"]
            .as_str()
            .context("Unexpected Ollama response format")?;
        Ok(strip_fences(sql.trim()))
    }
}
/// Removes a surrounding Markdown code fence (```sql … ``` or ``` … ```)
/// from a model response, returning the trimmed inner text; unfenced input
/// is returned trimmed and unchanged. An unterminated fence keeps everything
/// after the opening marker.
///
/// Bug fix: the previous version searched for the closing fence from the
/// start of the string, so for "```sql"-fenced input it found the *opening*
/// fence at index 0 and then sliced `text[5..0]`, panicking on every fenced
/// response (it also used 5 instead of the 6-byte prefix length). Searching
/// in the remainder after stripping the prefix fixes both.
fn strip_fences(text: &str) -> String {
    let text = text.trim();
    if let Some(rest) = text.strip_prefix("```sql") {
        let end = rest.find("```").unwrap_or(rest.len());
        rest[..end].trim().to_string()
    } else if let Some(rest) = text.strip_prefix("```") {
        let end = rest.find("```").unwrap_or(rest.len());
        rest[..end].trim().to_string()
    } else {
        text.to_string()
    }
}

146
ask/src/table_selector.rs Normal file
View File

@@ -0,0 +1,146 @@
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
// Default cutoff for cosine similarity when selecting tables.
// NOTE(review): this constant is not referenced anywhere in this file — the
// threshold is always passed in explicitly; confirm whether it should be
// used as a fallback or removed.
const DEFAULT_SIMILARITY_THRESHOLD: f32 = 0.35;

/// One table's precomputed embedding plus the text it was computed from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableEmbedding {
    pub id: String,
    pub text: String,
    pub embedding: Vec<f32>,
}

/// On-disk embeddings index: every table's embedding plus the name of the
/// model that produced them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingsIndex {
    pub tables: Vec<TableEmbedding>,
    pub model: String,
}

/// Selects tables relevant to a question by cosine similarity against the
/// precomputed table embeddings.
pub struct TableSelector {
    tables: Vec<TableEmbedding>,
    threshold: f32,
}
impl TableSelector {
    /// Loads a serialized `EmbeddingsIndex` from `embeddings_path` and keeps
    /// only the table embeddings plus the given similarity cutoff.
    pub fn new<P: AsRef<Path>>(embeddings_path: P, threshold: f32) -> anyhow::Result<Self> {
        let raw = fs::read_to_string(embeddings_path)?;
        let index: EmbeddingsIndex = serde_json::from_str(&raw)?;
        Ok(Self {
            tables: index.tables,
            threshold,
        })
    }

    /// Embeds `question` with `model` and returns the ids of every table
    /// whose cosine similarity meets the threshold, best match first.
    /// Each kept match is also logged to stderr.
    pub fn select_tables(
        &self,
        question: &str,
        model: &dyn QuestionEmbedder,
    ) -> anyhow::Result<Vec<String>> {
        let q_emb = model.embed(question)?;
        // Score every table against the question embedding.
        let mut scored: Vec<(usize, f32)> = Vec::with_capacity(self.tables.len());
        for (idx, table) in self.tables.iter().enumerate() {
            scored.push((idx, cosine_similarity(&q_emb, &table.embedding)));
        }
        // Highest similarity first; NaN compares as equal to keep the sort total.
        scored.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));
        let mut chosen = Vec::new();
        for (idx, score) in scored {
            if score >= self.threshold {
                eprintln!(" {} (similarity: {:.3})", self.tables[idx].id, score);
                chosen.push(self.tables[idx].id.clone());
            }
        }
        Ok(chosen)
    }

    /// Returns the stored description text for each requested table id,
    /// silently skipping ids that are not in the index.
    pub fn get_table_texts(&self, table_ids: &[String]) -> Vec<String> {
        let mut texts = Vec::new();
        for id in table_ids {
            if let Some(table) = self.tables.iter().find(|t| &t.id == id) {
                texts.push(table.text.clone());
            }
        }
        texts
    }

    /// Number of tables in the loaded index.
    pub fn table_count(&self) -> usize {
        self.tables.len()
    }
}
/// Anything that can turn a question string into an embedding vector.
pub trait QuestionEmbedder: Send + Sync {
    /// Returns the embedding for `text`.
    fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>>;
}
/// Embeds questions by shelling out to a local Python
/// `sentence_transformers` installation (see the `QuestionEmbedder` impl).
pub struct LocalEmbedder {
    // Model name or path passed to SentenceTransformer,
    // e.g. "all-MiniLM-L6-v2".
    model_path: String,
}

impl LocalEmbedder {
    /// Creates an embedder for the given sentence-transformers model.
    pub fn new(model_path: String) -> Self {
        Self { model_path }
    }
}
impl QuestionEmbedder for LocalEmbedder {
    /// Embeds `text` by running a small Python snippet (via `python3 -c`)
    /// that loads a sentence-transformers model and prints the embedding as
    /// a JSON array, which is then parsed back into `Vec<f32>`.
    ///
    /// Bug fix: the text was previously spliced into the Python source with
    /// only single quotes escaped, so inputs containing backslashes produced
    /// wrong embeddings and newlines broke the snippet entirely (arbitrary
    /// input could even inject Python code). Both the text and the model
    /// path are now serialized as JSON string literals, which are valid
    /// Python string literals for any content.
    fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
        use std::process::Command;
        // JSON string literals double as Python string literals, making the
        // interpolation safe for quotes, backslashes and newlines.
        let text_literal = serde_json::to_string(text)?;
        let model_literal = serde_json::to_string(&self.model_path)?;
        let script = format!(
            r#"
import json
from sentence_transformers import SentenceTransformer
model = SentenceTransformer({})
emb = model.encode({}, convert_to_numpy=True)
print(json.dumps([float(x) for x in emb]))
"#,
            model_literal, text_literal
        );
        let output = Command::new("python3").args(["-c", &script]).output()?;
        if !output.status.success() {
            let err = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Embedding generation failed: {}", err);
        }
        let output_str = String::from_utf8_lossy(&output.stdout);
        let floats: Vec<f32> = serde_json::from_str(&output_str)?;
        Ok(floats)
    }
}
/// Cosine similarity of two vectors; returns 0.0 when either vector has a
/// zero norm (avoids dividing by zero). Extra elements of the longer vector
/// are ignored, matching `zip` semantics.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0f32;
    for (x, y) in a.iter().zip(b) {
        dot += x * y;
    }
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Convenience wrapper: loads the embeddings index at `embeddings_path` and
/// returns the ids of tables relevant to `question`, using the local
/// "all-MiniLM-L6-v2" sentence-transformers model for the question embedding.
pub fn select_tables_from_question(
    question: &str,
    embeddings_path: &str,
    threshold: f32,
) -> anyhow::Result<Vec<String>> {
    let embedder = LocalEmbedder::new(String::from("all-MiniLM-L6-v2"));
    TableSelector::new(embeddings_path, threshold)?.select_tables(question, &embedder)
}