refactor: reorganize project structure and fix broken references
- Move scripts to scripts/ directory (roda.sh, prepara_db.py, etc.)
- Move shell config to shell/ directory (Caddyfile, auth.py, haloy.yml)
- Move basedosdados.duckdb to data/ directory
- Update Dockerfile and start.sh with new file paths
- Update README.md with correct script paths
- Remove Python ask.py (replaced by Rust binary in ask/ask)
- Add Rust source files (schema_filter.rs, sql_generator.rs, table_selector.rs)
- Remove sentence-transformer dependencies from ask
- Move docs and context artifacts to their directories
This commit is contained in:
227
ask/src/main.rs
227
ask/src/main.rs
@@ -1,4 +1,9 @@
|
||||
mod schema_filter;
|
||||
mod sql_generator;
|
||||
mod table_selector;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::Utc;
|
||||
use crossterm::{
|
||||
event::{
|
||||
DisableBracketedPaste, DisableMouseCapture, EnableBracketedPaste, EnableMouseCapture,
|
||||
@@ -9,14 +14,12 @@ use crossterm::{
|
||||
};
|
||||
use duckdb::Connection;
|
||||
use ratatui::{
|
||||
buffer::Buffer,
|
||||
layout::{Constraint, Direction, Layout, Rect},
|
||||
style::{Color, Modifier, Style},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Gauge, Paragraph, Row, Table, TableState, Wrap},
|
||||
Frame, Terminal,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use serde_json::{json, Value};
|
||||
use std::{
|
||||
env, fs,
|
||||
@@ -43,6 +46,10 @@ struct Config {
|
||||
schema: String,
|
||||
db_file: String,
|
||||
prompt_file: String,
|
||||
use_table_selection: bool,
|
||||
embeddings_file: String,
|
||||
schema_json: String,
|
||||
similarity_threshold: f32,
|
||||
}
|
||||
|
||||
enum Phase {
|
||||
@@ -234,10 +241,23 @@ fn spawn_worker(
|
||||
model: String,
|
||||
prompt_file: String,
|
||||
db_file: String,
|
||||
use_table_selection: bool,
|
||||
embeddings_file: String,
|
||||
schema_json: String,
|
||||
similarity_threshold: f32,
|
||||
) -> mpsc::Receiver<WorkerMsg> {
|
||||
let (tx, rx) = mpsc::channel::<WorkerMsg>();
|
||||
std::thread::spawn(
|
||||
move || match ask_model(&question, &schema, &model, &prompt_file) {
|
||||
std::thread::spawn(move || {
|
||||
match ask_model_with_selection(
|
||||
&question,
|
||||
&schema,
|
||||
&model,
|
||||
&prompt_file,
|
||||
use_table_selection,
|
||||
&embeddings_file,
|
||||
&schema_json,
|
||||
similarity_threshold,
|
||||
) {
|
||||
Err(e) => {
|
||||
let err = format!("{:#}", e);
|
||||
log_question(&question, "", false, Some(&err));
|
||||
@@ -257,8 +277,8 @@ fn spawn_worker(
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
});
|
||||
rx
|
||||
}
|
||||
|
||||
@@ -270,6 +290,10 @@ fn spawn_retry_worker(
|
||||
model: String,
|
||||
prompt_file: String,
|
||||
db_file: String,
|
||||
use_table_selection: bool,
|
||||
embeddings_file: String,
|
||||
schema_json: String,
|
||||
similarity_threshold: f32,
|
||||
) -> mpsc::Receiver<WorkerMsg> {
|
||||
let retry_q = format!(
|
||||
"{}\n\nO SQL que você gerou falhou com este erro DuckDB:\n```\n{}\n```\n\n\
|
||||
@@ -277,7 +301,17 @@ fn spawn_retry_worker(
|
||||
Corrija o SQL. Retorne APENAS o SQL corrigido, sem explicação.",
|
||||
question, error, failed_sql
|
||||
);
|
||||
spawn_worker(retry_q, schema, model, prompt_file, db_file)
|
||||
spawn_worker(
|
||||
retry_q,
|
||||
schema,
|
||||
model,
|
||||
prompt_file,
|
||||
db_file,
|
||||
use_table_selection,
|
||||
embeddings_file,
|
||||
schema_json,
|
||||
similarity_threshold,
|
||||
)
|
||||
}
|
||||
|
||||
// ── event handling ────────────────────────────────────────────────────────────
|
||||
@@ -327,6 +361,10 @@ impl App {
|
||||
self.config.model.clone(),
|
||||
self.config.prompt_file.clone(),
|
||||
self.config.db_file.clone(),
|
||||
self.config.use_table_selection,
|
||||
self.config.embeddings_file.clone(),
|
||||
self.config.schema_json.clone(),
|
||||
self.config.similarity_threshold,
|
||||
));
|
||||
}
|
||||
|
||||
@@ -398,6 +436,10 @@ impl App {
|
||||
self.config.model.clone(),
|
||||
self.config.prompt_file.clone(),
|
||||
self.config.db_file.clone(),
|
||||
self.config.use_table_selection,
|
||||
self.config.embeddings_file.clone(),
|
||||
self.config.schema_json.clone(),
|
||||
self.config.similarity_threshold,
|
||||
));
|
||||
self.last_sql.clear();
|
||||
} else {
|
||||
@@ -723,7 +765,12 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
|
||||
let col_max_widths: Vec<usize> = (0..col_count)
|
||||
.map(|i| {
|
||||
let header_len = cols[i].len();
|
||||
let data_len = rows.iter().filter_map(|r| r.get(i)).map(|c| c.len()).max().unwrap_or(0);
|
||||
let data_len = rows
|
||||
.iter()
|
||||
.filter_map(|r| r.get(i))
|
||||
.map(|c| c.len())
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
(header_len.max(data_len)).max(min_col_width as usize)
|
||||
})
|
||||
.collect();
|
||||
@@ -732,16 +779,24 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
|
||||
let use_wrap = total_needed > available_width as usize;
|
||||
|
||||
if use_wrap {
|
||||
let wrap_width = (available_width as usize / col_count).max(min_col_width as usize);
|
||||
let header_lines: Vec<Line> = cols.iter()
|
||||
let wrap_width =
|
||||
(available_width as usize / col_count).max(min_col_width as usize);
|
||||
let header_lines: Vec<Line> = cols
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, c)| {
|
||||
let wrapped = wrap_text(c, wrap_width);
|
||||
Line::from(wrapped)
|
||||
let spans: Vec<Span> =
|
||||
wrapped.into_iter().map(|s| Span::raw(s)).collect();
|
||||
Line::from(spans)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let max_header_lines = header_lines.iter().map(|l| l.len()).max().unwrap_or(1);
|
||||
let max_header_lines = header_lines
|
||||
.iter()
|
||||
.map(|l| l.spans.len())
|
||||
.max()
|
||||
.unwrap_or(1);
|
||||
|
||||
let mut all_row_lines: Vec<Vec<Line>> = Vec::new();
|
||||
for row in rows {
|
||||
@@ -749,19 +804,19 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
|
||||
.map(|i| {
|
||||
let cell = row.get(i).map(|s| s.as_str()).unwrap_or("");
|
||||
let wrapped = wrap_text(cell, wrap_width);
|
||||
Line::from(wrapped)
|
||||
let spans: Vec<Span> =
|
||||
wrapped.into_iter().map(|s| Span::raw(s)).collect();
|
||||
Line::from(spans)
|
||||
})
|
||||
.collect();
|
||||
let max_lines = row_lines.iter().map(|l| l.len()).max().unwrap_or(1);
|
||||
let max_lines = row_lines.iter().map(|l| l.spans.len()).max().unwrap_or(1);
|
||||
all_row_lines.push(row_lines);
|
||||
}
|
||||
|
||||
let selected_idx = table_state.selected().unwrap_or(0);
|
||||
let table_title = format!(" Resultados ({}/{}) ", selected_idx + 1, n);
|
||||
|
||||
let block = Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.title(table_title);
|
||||
let block = Block::default().borders(Borders::ALL).title(table_title);
|
||||
|
||||
let area = chunks[2];
|
||||
f.render_widget(block, area);
|
||||
@@ -778,29 +833,32 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
|
||||
|
||||
let start_row = if n > visible_rows as usize {
|
||||
let scroll = selected_idx as i32 - visible_rows as i32 / 2;
|
||||
scroll.max(0) as usize.min(n.saturating_sub(visible_rows as usize))
|
||||
(scroll.max(0) as usize).min(n.saturating_sub(visible_rows as usize))
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let header_bg = Style::default().fg(Color::Yellow).add_modifier(Modifier::BOLD);
|
||||
let header_bg = Style::default()
|
||||
.fg(Color::Yellow)
|
||||
.add_modifier(Modifier::BOLD);
|
||||
for (col_idx, header_line) in header_lines.iter().enumerate() {
|
||||
let col_x = inner_area.x + (col_idx as u16) * (wrap_width as u16 + 1);
|
||||
let col_width = wrap_width as u16;
|
||||
for (line_idx, line) in header_line.iter().enumerate() {
|
||||
for (line_idx, span) in header_line.spans.iter().enumerate() {
|
||||
let y = inner_area.y + line_idx as u16;
|
||||
if y >= inner_area.y + inner_area.height {
|
||||
break;
|
||||
}
|
||||
let spans: Vec<Span> = line.spans.iter().map(|s| {
|
||||
Span::styled(s.content.clone(), header_bg)
|
||||
}).collect();
|
||||
f.render_widget(Paragraph::new(Line::from(spans)), Rect {
|
||||
x: col_x,
|
||||
y,
|
||||
width: col_width,
|
||||
height: 1,
|
||||
});
|
||||
let styled_span = Span::styled(span.content.clone(), header_bg);
|
||||
f.render_widget(
|
||||
Paragraph::new(Line::from(styled_span)),
|
||||
Rect {
|
||||
x: col_x,
|
||||
y,
|
||||
width: col_width,
|
||||
height: 1,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -811,7 +869,9 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
|
||||
}
|
||||
let is_selected = row_idx == selected_idx;
|
||||
let row_style = if is_selected {
|
||||
Style::default().bg(Color::DarkGray).add_modifier(Modifier::BOLD)
|
||||
Style::default()
|
||||
.bg(Color::DarkGray)
|
||||
.add_modifier(Modifier::BOLD)
|
||||
} else {
|
||||
Style::default()
|
||||
};
|
||||
@@ -820,20 +880,21 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
|
||||
for (col_idx, cell_lines) in row_lines.iter().enumerate() {
|
||||
let col_x = inner_area.x + (col_idx as u16) * (wrap_width as u16 + 1);
|
||||
let col_width = wrap_width as u16;
|
||||
for (line_idx, line) in cell_lines.iter().enumerate() {
|
||||
for (line_idx, span) in cell_lines.spans.iter().enumerate() {
|
||||
let cell_y = y + line_idx as u16;
|
||||
if cell_y >= inner_area.y + inner_area.height {
|
||||
break;
|
||||
}
|
||||
let spans: Vec<Span> = line.spans.iter().map(|s| {
|
||||
Span::styled(s.content.clone(), row_style)
|
||||
}).collect();
|
||||
f.render_widget(Paragraph::new(Line::from(spans)), Rect {
|
||||
x: col_x,
|
||||
y: cell_y,
|
||||
width: col_width,
|
||||
height: 1,
|
||||
});
|
||||
let styled_span = Span::styled(span.content.clone(), row_style);
|
||||
f.render_widget(
|
||||
Paragraph::new(Line::from(styled_span)),
|
||||
Rect {
|
||||
x: col_x,
|
||||
y: cell_y,
|
||||
width: col_width,
|
||||
height: 1,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -850,7 +911,8 @@ fn draw_content(f: &mut Frame, app: &mut App, area: Rect) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let col_widths: Vec<Constraint> = cols.iter()
|
||||
let col_widths: Vec<Constraint> = cols
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, _)| {
|
||||
let w = col_max_widths[i] as u16;
|
||||
@@ -1008,6 +1070,55 @@ fn ask_model(question: &str, schema: &str, model: &str, prompt_file: &str) -> Re
|
||||
Ok(ensure_sql(&sql))
|
||||
}
|
||||
|
||||
fn ask_model_with_selection(
|
||||
question: &str,
|
||||
_full_schema: &str,
|
||||
model: &str,
|
||||
prompt_file: &str,
|
||||
use_selection: bool,
|
||||
embeddings_file: &str,
|
||||
schema_json: &str,
|
||||
similarity_threshold: f32,
|
||||
) -> Result<String> {
|
||||
let prompt_template = fs::read_to_string(prompt_file)
|
||||
.with_context(|| format!("Não foi possível ler o prompt: {}", prompt_file))?;
|
||||
|
||||
let (schema_to_use, selected_tables) = if use_selection {
|
||||
match table_selector::select_tables_from_question(
|
||||
question,
|
||||
embeddings_file,
|
||||
similarity_threshold,
|
||||
) {
|
||||
Ok(table_ids) => {
|
||||
eprintln!(
|
||||
"=> Selecionadas {} tables relevantes: {:?}",
|
||||
table_ids.len(),
|
||||
table_ids
|
||||
);
|
||||
let schema_filter = schema_filter::SchemaFilter::new(schema_json)?;
|
||||
let filtered_schema = schema_filter.filter_tables(&table_ids);
|
||||
(filtered_schema, Some(table_ids))
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"=> Aviso: falha na seleção de tables ({}), usando schema completo",
|
||||
e
|
||||
);
|
||||
let schema_filter = schema_filter::SchemaFilter::new(schema_json)?;
|
||||
(schema_filter.full_schema_text(), None)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let schema_filter = schema_filter::SchemaFilter::new(schema_json)?;
|
||||
(schema_filter.full_schema_text(), None)
|
||||
};
|
||||
|
||||
let generator = sql_generator::create_sql_generator()?;
|
||||
let sql = generator.generate(question, &schema_to_use, &prompt_template)?;
|
||||
|
||||
Ok(ensure_sql(&sql))
|
||||
}
|
||||
|
||||
fn ask_gemini(question: &str, system_prompt: &str, model: &str) -> Result<String> {
|
||||
let key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY não definida")?;
|
||||
let url = format!(
|
||||
@@ -1309,6 +1420,12 @@ VARIÁVEIS DE AMBIENTE
|
||||
OPENROUTER_API_KEY necessária para modelos OpenRouter
|
||||
GEMINI_MODEL modelo padrão (sobrescrito por --model)
|
||||
SCHEMA_FILE DDL do schema [context/schema_compact_inline.txt]
|
||||
SCHEMA_JSON full schema JSON [context/basedosdados-schema.json]
|
||||
EMBEDDINGS_FILE table embeddings [context/table_embeddings.json]
|
||||
TOP_K_TABLES número de tables a selecionar [5]
|
||||
SQL_GENERATOR sql generator: sqlcoder|gemini|openrouter [gemini]
|
||||
OLLAMA_MODEL modelo ollama [sqlcoder]
|
||||
OLLAMA_HOST host ollama [http://localhost:11434]
|
||||
PROMPT_FILE prompt do sistema [ask/system_prompt.md]
|
||||
DB_FILE banco DuckDB [basedosdados.duckdb]
|
||||
"#
|
||||
@@ -1321,7 +1438,18 @@ VARIÁVEIS DE AMBIENTE
|
||||
});
|
||||
let schema_file =
|
||||
env::var("SCHEMA_FILE").unwrap_or_else(|_| "context/schema_compact_inline.txt".into());
|
||||
let db_file = env::var("DB_FILE").unwrap_or_else(|_| "basedosdados.duckdb".into());
|
||||
let schema_json =
|
||||
env::var("SCHEMA_JSON").unwrap_or_else(|_| "context/basedosdados-schema.json".into());
|
||||
let embeddings_file =
|
||||
env::var("EMBEDDINGS_FILE").unwrap_or_else(|_| "context/table_embeddings.json".into());
|
||||
let similarity_threshold = env::var("SIMILARITY_THRESHOLD")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(0.35);
|
||||
let use_table_selection = env::var("USE_TABLE_SELECTION")
|
||||
.map(|v| v != "false" && v != "0")
|
||||
.unwrap_or(true);
|
||||
let db_file = env::var("DB_FILE").unwrap_or_else(|_| "data/basedosdados.duckdb".into());
|
||||
let prompt_file = env::var("PROMPT_FILE").unwrap_or_else(|_| "ask/system_prompt.md".into());
|
||||
let schema = fs::read_to_string(&schema_file)
|
||||
.with_context(|| format!("Não foi possível ler o schema: {}", schema_file))?;
|
||||
@@ -1333,6 +1461,10 @@ VARIÁVEIS DE AMBIENTE
|
||||
schema,
|
||||
db_file,
|
||||
prompt_file,
|
||||
use_table_selection,
|
||||
embeddings_file,
|
||||
schema_json,
|
||||
similarity_threshold,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1341,7 +1473,16 @@ VARIÁVEIS DE AMBIENTE
|
||||
eprintln!("\nModel: {}\nPergunta: {}\n", model, question);
|
||||
|
||||
let t0 = Instant::now();
|
||||
let sql = ask_model(&question, &schema, &model, &prompt_file)?;
|
||||
let sql = ask_model_with_selection(
|
||||
&question,
|
||||
&schema,
|
||||
&model,
|
||||
&prompt_file,
|
||||
use_table_selection,
|
||||
&embeddings_file,
|
||||
&schema_json,
|
||||
similarity_threshold,
|
||||
)?;
|
||||
eprintln!("=> SQL gerado em {}", fmt_duration(t0.elapsed()));
|
||||
print_sql_box(&sql);
|
||||
|
||||
|
||||
135
ask/src/schema_filter.rs
Normal file
135
ask/src/schema_filter.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Column {
|
||||
pub name: String,
|
||||
#[serde(rename = "type")]
|
||||
pub col_type: String,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
pub type TableColumns = Vec<Column>;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FullSchema {
|
||||
#[serde(flatten)]
|
||||
pub datasets:
|
||||
std::collections::HashMap<String, std::collections::HashMap<String, TableColumns>>,
|
||||
}
|
||||
|
||||
pub struct SchemaFilter {
|
||||
schema: FullSchema,
|
||||
}
|
||||
|
||||
impl SchemaFilter {
|
||||
pub fn new<P: AsRef<Path>>(schema_path: P) -> anyhow::Result<Self> {
|
||||
let content = fs::read_to_string(schema_path)?;
|
||||
let schema: FullSchema = serde_json::from_str(&content)?;
|
||||
Ok(Self { schema })
|
||||
}
|
||||
|
||||
pub fn filter_tables(&self, table_ids: &[String]) -> String {
|
||||
let selected: HashSet<String> = table_ids.iter().cloned().collect();
|
||||
let mut lines = Vec::new();
|
||||
|
||||
lines.push("# Base dos Dados — Filtered Schema".to_string());
|
||||
lines.push(
|
||||
"# Legend: V=VARCHAR I=INT D=DOUBLE Dt=DATE B=BOOLEAN Dec=DECIMAL Ts=TIMESTAMP Ti=TIME"
|
||||
.to_string(),
|
||||
);
|
||||
lines.push("# Format: dataset.table: col:TYPE description".to_string());
|
||||
lines.push(String::new());
|
||||
|
||||
for (dataset, tables) in &self.schema.datasets {
|
||||
for (table, columns) in tables {
|
||||
let full_id = format!("{}.{}", dataset, table);
|
||||
if selected.contains(&full_id) {
|
||||
let col_str = columns
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let desc = c.description.as_deref().unwrap_or("");
|
||||
if desc.is_empty() {
|
||||
format!("{}:{}", c.name, type_abbrev(&c.col_type))
|
||||
} else {
|
||||
format!("{}:{} {}", c.name, type_abbrev(&c.col_type), desc)
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
|
||||
lines.push(format!("{}: {}", full_id, col_str));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
pub fn full_schema_text(&self) -> String {
|
||||
let mut lines = Vec::new();
|
||||
|
||||
lines.push("# Base dos Dados — Full Schema".to_string());
|
||||
lines.push(
|
||||
"# Legend: V=VARCHAR I=INT D=DOUBLE Dt=DATE B=BOOLEAN Dec=DECIMAL Ts=TIMESTAMP Ti=TIME"
|
||||
.to_string(),
|
||||
);
|
||||
lines.push("# Format: dataset.table: col:TYPE description".to_string());
|
||||
lines.push(String::new());
|
||||
|
||||
for (dataset, tables) in &self.schema.datasets {
|
||||
for (table, columns) in tables {
|
||||
let full_id = format!("{}.{}", dataset, table);
|
||||
let col_str = columns
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let desc = c.description.as_deref().unwrap_or("");
|
||||
if desc.is_empty() {
|
||||
format!("{}:{}", c.name, type_abbrev(&c.col_type))
|
||||
} else {
|
||||
format!("{}:{} {}", c.name, type_abbrev(&c.col_type), desc)
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
|
||||
lines.push(format!("{}: {}", full_id, col_str));
|
||||
}
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
pub fn dataset_count(&self) -> usize {
|
||||
self.schema.datasets.len()
|
||||
}
|
||||
|
||||
pub fn table_count(&self) -> usize {
|
||||
self.schema.datasets.values().map(|t| t.len()).sum()
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps a SQL type name to its one/two-letter legend abbreviation.
///
/// Matching is by case-insensitive substring, checked in priority order
/// (TIMESTAMP must be tested before the bare TIME/DATE checks, since it
/// contains both substrings). Unknown types are passed through unchanged.
fn type_abbrev(full_type: &str) -> String {
    let upper = full_type.to_uppercase();
    let has = |needle: &str| upper.contains(needle);

    let abbrev = if has("VARCHAR") || has("STRING") {
        "V"
    } else if has("INT") {
        "I"
    } else if has("DOUBLE") || has("FLOAT") {
        "D"
    } else if has("DATE") && !has("TIMESTAMP") {
        "Dt"
    } else if has("TIMESTAMP") {
        "Ts"
    } else if has("TIME") {
        "Ti"
    } else if has("BOOLEAN") {
        "B"
    } else if has("DECIMAL") {
        "Dec"
    } else {
        // Not in the legend: return the original type text verbatim.
        return full_type.to_string();
    };
    abbrev.to_string()
}
|
||||
207
ask/src/sql_generator.rs
Normal file
207
ask/src/sql_generator.rs
Normal file
@@ -0,0 +1,207 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde_json::Value;
|
||||
use std::env;
|
||||
|
||||
pub trait SqlGenerator: Send + Sync {
|
||||
fn generate(&self, question: &str, schema: &str, prompt_template: &str) -> Result<String>;
|
||||
}
|
||||
|
||||
pub fn create_sql_generator() -> Result<Box<dyn SqlGenerator>> {
|
||||
let generator_type = env::var("SQL_GENERATOR").unwrap_or_else(|_| "gemini".to_string());
|
||||
|
||||
match generator_type.as_str() {
|
||||
"sqlcoder" => Ok(Box::new(SqlCoderGenerator::new()?)),
|
||||
"openrouter" => Ok(Box::new(OpenRouterGenerator::new()?)),
|
||||
"gemini" => Ok(Box::new(GeminiGenerator::new()?)),
|
||||
_ => anyhow::bail!(
|
||||
"Unknown SQL_GENERATOR: {}. Use: sqlcoder, gemini, or openrouter",
|
||||
generator_type
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GeminiGenerator {
|
||||
model: String,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl GeminiGenerator {
|
||||
pub fn new() -> Result<Self> {
|
||||
let model = env::var("GEMINI_MODEL").unwrap_or_else(|_| "gemini-flash-latest".to_string());
|
||||
let api_key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY not defined")?;
|
||||
Ok(Self { model, api_key })
|
||||
}
|
||||
}
|
||||
|
||||
impl SqlGenerator for GeminiGenerator {
|
||||
fn generate(&self, question: &str, schema: &str, prompt_template: &str) -> Result<String> {
|
||||
let url = format!(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
|
||||
self.model
|
||||
);
|
||||
|
||||
let system_prompt = format!("{}\n\nSchema DDL:\n\n{}", prompt_template.trim(), schema);
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"system_instruction": { "parts": [{ "text": system_prompt }] },
|
||||
"contents": [{ "parts": [{ "text": question }] }]
|
||||
});
|
||||
|
||||
let client = reqwest::blocking::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.build()?;
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("X-goog-api-key", &self.api_key)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.context("Gemini HTTP request failed")?;
|
||||
|
||||
let status = resp.status();
|
||||
let body: Value = resp.json().context("Failed to parse Gemini response")?;
|
||||
|
||||
if !status.is_success() {
|
||||
anyhow::bail!("Gemini API error {}: {}", status, body);
|
||||
}
|
||||
|
||||
let text = body["candidates"][0]["content"]["parts"][0]["text"]
|
||||
.as_str()
|
||||
.context("Unexpected Gemini response format")?
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
Ok(strip_fences(&text))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OpenRouterGenerator {
|
||||
model: String,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl OpenRouterGenerator {
|
||||
pub fn new() -> Result<Self> {
|
||||
let model =
|
||||
env::var("OPENROUTER_MODEL").unwrap_or_else(|_| "openai/gpt-4o-mini".to_string());
|
||||
let api_key = env::var("OPENROUTER_API_KEY").context("OPENROUTER_API_KEY not defined")?;
|
||||
Ok(Self { model, api_key })
|
||||
}
|
||||
}
|
||||
|
||||
impl SqlGenerator for OpenRouterGenerator {
|
||||
fn generate(&self, question: &str, schema: &str, prompt_template: &str) -> Result<String> {
|
||||
let url = "https://openrouter.ai/api/v1/chat/completions";
|
||||
|
||||
let system_prompt = format!("{}\n\nSchema DDL:\n\n{}", prompt_template.trim(), schema);
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{ "role": "system", "content": system_prompt },
|
||||
{ "role": "user", "content": question }
|
||||
]
|
||||
});
|
||||
|
||||
let client = reqwest::blocking::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.build()?;
|
||||
|
||||
let resp = client
|
||||
.post(url)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("HTTP-Referer", "https://basedosdados.org")
|
||||
.header("X-Title", "Base dos Dados Ask")
|
||||
.json(&payload)
|
||||
.send()
|
||||
.context("OpenRouter HTTP request failed")?;
|
||||
|
||||
let status = resp.status();
|
||||
let body: Value = resp.json().context("Failed to parse OpenRouter response")?;
|
||||
|
||||
if !status.is_success() {
|
||||
anyhow::bail!("OpenRouter API error {}: {}", status, body);
|
||||
}
|
||||
|
||||
let text = body["choices"][0]["message"]["content"]
|
||||
.as_str()
|
||||
.context("Unexpected OpenRouter response format")?
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
Ok(strip_fences(&text))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SqlCoderGenerator {
|
||||
model: String,
|
||||
host: String,
|
||||
}
|
||||
|
||||
impl SqlCoderGenerator {
|
||||
pub fn new() -> Result<Self> {
|
||||
let model = env::var("OLLAMA_MODEL").unwrap_or_else(|_| "sqlcoder".to_string());
|
||||
let host = env::var("OLLAMA_HOST").unwrap_or_else(|_| "http://localhost:11434".to_string());
|
||||
Ok(Self { model, host })
|
||||
}
|
||||
}
|
||||
|
||||
impl SqlGenerator for SqlCoderGenerator {
|
||||
fn generate(&self, question: &str, schema: &str, prompt_template: &str) -> Result<String> {
|
||||
let url = format!("{}/api/generate", self.host);
|
||||
|
||||
let full_prompt = format!(
|
||||
"{}\n\nSchema DDL:\n\n{}\n\nQuestion: {}\n\nSQL:",
|
||||
prompt_template.trim(),
|
||||
schema,
|
||||
question
|
||||
);
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"model": self.model,
|
||||
"prompt": full_prompt,
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let client = reqwest::blocking::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.build()?;
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&payload)
|
||||
.send()
|
||||
.context("Ollama HTTP request failed")?;
|
||||
|
||||
let status = resp.status();
|
||||
let body: Value = resp.json().context("Failed to parse Ollama response")?;
|
||||
|
||||
if !status.is_success() {
|
||||
anyhow::bail!("Ollama API error {}: {}", status, body);
|
||||
}
|
||||
|
||||
let text = body["response"]
|
||||
.as_str()
|
||||
.context("Unexpected Ollama response format")?
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
Ok(strip_fences(&text))
|
||||
}
|
||||
}
|
||||
|
||||
/// Strips a surrounding Markdown code fence (```sql ... ``` or ``` ... ```)
/// from the model's reply, returning the trimmed inner text. Input without a
/// leading fence is returned trimmed but otherwise unchanged; an unterminated
/// fence keeps everything after the opening marker.
///
/// Bug fixes vs. the previous version for the "```sql" case:
/// - `text.find("```")` located the OPENING fence at index 0, so the slice
///   `text[5..0]` panicked at runtime on every fenced reply;
/// - the prefix length was off by one ("```sql" is 6 bytes, not 5), which
///   would have left a stray 'l' at the start of the SQL.
/// Searching only in the remainder after `strip_prefix` fixes both.
fn strip_fences(text: &str) -> String {
    let text = text.trim();
    if let Some(rest) = text.strip_prefix("```sql") {
        let end = rest.find("```").unwrap_or(rest.len());
        rest[..end].trim().to_string()
    } else if let Some(rest) = text.strip_prefix("```") {
        let end = rest.find("```").unwrap_or(rest.len());
        rest[..end].trim().to_string()
    } else {
        text.to_string()
    }
}
|
||||
146
ask/src/table_selector.rs
Normal file
146
ask/src/table_selector.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
const DEFAULT_SIMILARITY_THRESHOLD: f32 = 0.35;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TableEmbedding {
|
||||
pub id: String,
|
||||
pub text: String,
|
||||
pub embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingsIndex {
|
||||
pub tables: Vec<TableEmbedding>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
pub struct TableSelector {
|
||||
tables: Vec<TableEmbedding>,
|
||||
threshold: f32,
|
||||
}
|
||||
|
||||
impl TableSelector {
|
||||
pub fn new<P: AsRef<Path>>(embeddings_path: P, threshold: f32) -> anyhow::Result<Self> {
|
||||
let content = fs::read_to_string(embeddings_path)?;
|
||||
let index: EmbeddingsIndex = serde_json::from_str(&content)?;
|
||||
Ok(Self {
|
||||
tables: index.tables,
|
||||
threshold,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn select_tables(
|
||||
&self,
|
||||
question: &str,
|
||||
model: &dyn QuestionEmbedder,
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
let question_embedding = model.embed(question)?;
|
||||
|
||||
let mut similarities: Vec<(usize, f32)> = self
|
||||
.tables
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, table)| {
|
||||
let sim = cosine_similarity(&question_embedding, &table.embedding);
|
||||
(i, sim)
|
||||
})
|
||||
.collect();
|
||||
|
||||
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
let selected: Vec<String> = similarities
|
||||
.into_iter()
|
||||
.filter(|(_, sim)| *sim >= self.threshold)
|
||||
.map(|(i, sim)| {
|
||||
eprintln!(" {} (similarity: {:.3})", self.tables[i].id, sim);
|
||||
self.tables[i].id.clone()
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(selected)
|
||||
}
|
||||
|
||||
pub fn get_table_texts(&self, table_ids: &[String]) -> Vec<String> {
|
||||
table_ids
|
||||
.iter()
|
||||
.filter_map(|id| self.tables.iter().find(|t| &t.id == id))
|
||||
.map(|t| t.text.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn table_count(&self) -> usize {
|
||||
self.tables.len()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait QuestionEmbedder: Send + Sync {
|
||||
fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>>;
|
||||
}
|
||||
|
||||
pub struct LocalEmbedder {
|
||||
model_path: String,
|
||||
}
|
||||
|
||||
impl LocalEmbedder {
|
||||
pub fn new(model_path: String) -> Self {
|
||||
Self { model_path }
|
||||
}
|
||||
}
|
||||
|
||||
impl QuestionEmbedder for LocalEmbedder {
|
||||
fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
|
||||
use std::process::Command;
|
||||
|
||||
let output = Command::new("python3")
|
||||
.args([
|
||||
"-c",
|
||||
&format!(
|
||||
r#"
|
||||
import json
|
||||
from sentence_transformers import SentenceTransformer
|
||||
model = SentenceTransformer('{}')
|
||||
emb = model.encode('{}', convert_to_numpy=True)
|
||||
print(json.dumps([float(x) for x in emb]))
|
||||
"#,
|
||||
self.model_path,
|
||||
text.replace("'", "\\'")
|
||||
),
|
||||
])
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
let err = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("Embedding generation failed: {}", err);
|
||||
}
|
||||
|
||||
let output_str = String::from_utf8_lossy(&output.stdout);
|
||||
let floats: Vec<f32> = serde_json::from_str(&output_str)?;
|
||||
|
||||
Ok(floats)
|
||||
}
|
||||
}
|
||||
|
||||
/// Cosine similarity of two vectors; 0.0 when either has zero magnitude
/// (avoids a division by zero rather than returning NaN).
///
/// The dot product runs over the zipped (shorter) length while each norm is
/// taken over its own full slice, matching the original behaviour for
/// unequal lengths.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot = a.iter().zip(b.iter()).fold(0.0_f32, |acc, (x, y)| acc + x * y);
    let norm = |v: &[f32]| v.iter().fold(0.0_f32, |acc, x| acc + x * x).sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));

    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
|
||||
|
||||
pub fn select_tables_from_question(
|
||||
question: &str,
|
||||
embeddings_path: &str,
|
||||
threshold: f32,
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
let selector = TableSelector::new(embeddings_path, threshold)?;
|
||||
let embedder = LocalEmbedder::new("all-MiniLM-L6-v2".to_string());
|
||||
selector.select_tables(question, &embedder)
|
||||
}
|
||||
Reference in New Issue
Block a user