Initial import

This commit is contained in:
Earlopain 2023-04-05 12:43:13 +02:00
commit d2cf1f0ce5
No known key found for this signature in database
GPG Key ID: 6CFB948E15246897
7 changed files with 2893 additions and 0 deletions

7
.env.sample Normal file
View File

@@ -0,0 +1,7 @@
# Sample environment for the service. Copy to .env and fill in the values;
# dotenv loads .env at startup and the config module reads these variables.
SERVER_ADDR=
PG.USER=
PG.PASSWORD=
PG.HOST=
PG.PORT=
PG.DBNAME=
PG.POOL.MAX_SIZE=

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
.env
/target

2667
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

22
Cargo.toml Normal file
View File

@@ -0,0 +1,22 @@
[package]
name = "autocompleted"
version = "0.1.0"
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
actix-web = "3"                                          # HTTP server framework
moka = { version = "0.6.1", features = ["future"] }      # async in-memory cache (response cache)
config = "0.10.1"                                        # env-based configuration loading
deadpool-postgres = "0.5.0"                              # Postgres connection pool
derive_more = "0.99.2"                                   # Display/Error/From derives for the error enum
dotenv = "0.15.0"                                        # loads .env into the environment at startup
serde = { version = "1.0.104", features = ["derive"] }   # (de)serialization derives
serde_json = "1.0"                                       # JSON response bodies
tokio-pg-mapper = "0.1"                                  # row -> struct mapping
tokio-pg-mapper-derive = "0.1"                           # derive macro for the above
tokio-postgres = "0.5.1"                                 # async Postgres driver
unicode-normalization = "0.1.19"                         # NFC normalization of search input
log = "0.4"                                              # logging facade
env_logger = "0.9.0"                                     # logger backend configured via env vars

8
sql/fetch_tags_a.sql Normal file
View File

@@ -0,0 +1,8 @@
-- Prefix autocomplete: union direct tag-name matches with tag-alias matches,
-- dedupe on (name, post_count), and keep the 10 most-used results.
-- $1 is a LIKE pattern (used with ESCAPE E'\\') built from the user's prefix.
-- Alias rows exclude tags whose own name already matches $1, to avoid duplicates.
SELECT DISTINCT ON (name, post_count) * FROM (
(SELECT tags.id, tags.name, tags.post_count, tags.category, null AS antecedent_name FROM
"tags" WHERE (tags.name LIKE $1 ESCAPE E'\\') AND (post_count > 0) ORDER BY post_count desc LIMIT 10)
UNION ALL
(SELECT tags.id, tags.name, tags.post_count, tags.category, tag_aliases.antecedent_name
FROM "tag_aliases"
INNER JOIN tags ON tags.name = tag_aliases.consequent_name
WHERE (tag_aliases.antecedent_name LIKE $1 ESCAPE E'\\') AND "tag_aliases"."status" IN ('active', 'processing', 'queued') AND (tags.name NOT LIKE $1 ESCAPE E'\\') AND (tag_aliases.post_count > 0) ORDER BY tag_aliases.post_count desc LIMIT 20)) AS unioned_query ORDER BY post_count desc LIMIT 10

1
sql/fetch_tags_b.sql Normal file
View File

@@ -0,0 +1 @@
SELECT tags.id, tags.name, tags.post_count, tags.category, null AS antecedent_name FROM "tags" WHERE (tags.name % $1) AND (tags.post_count > 0) ORDER BY trunc(3 * similarity(name, $1)) DESC, post_count DESC, name DESC LIMIT 10

186
src/main.rs Normal file
View File

@@ -0,0 +1,186 @@
use actix_web::{dev::HttpResponseBuilder, error, get, http::header, http::StatusCode, HttpResponse, web};
use deadpool_postgres::Pool;
use derive_more::{Display, Error, From};
use moka::future::Cache;
use log::error;
use serde::Deserialize;
mod config {
    pub use ::config::ConfigError;
    use serde::Deserialize;

    /// Application settings deserialized from environment variables
    /// (SERVER_ADDR plus the PG.* keys listed in .env.sample).
    #[derive(Deserialize)]
    pub struct Config {
        pub server_addr: String,
        pub pg: deadpool_postgres::Config,
    }

    impl Config {
        /// Load settings by merging the process environment into a
        /// config-crate builder and deserializing it into `Self`.
        /// Returns `ConfigError` if required keys are missing or malformed.
        pub fn from_env() -> Result<Self, ConfigError> {
            let mut cfg = ::config::Config::new();
            cfg.merge(::config::Environment::new())?;
            cfg.try_into()
        }
    }
}
mod models {
    use serde::{Deserialize, Serialize};
    use tokio_pg_mapper_derive::PostgresMapper;

    /// One autocomplete result row, matching the column list produced by
    /// sql/fetch_tags_a.sql and sql/fetch_tags_b.sql.
    #[derive(Deserialize, PostgresMapper, Serialize)]
    #[pg_mapper(table = "tags")]
    pub struct Tag {
        pub id: i32,
        pub name: String,
        // Queries filter on post_count > 0, so only in-use tags are returned.
        pub post_count: i32,
        pub category: i16,
        // Populated only by the tag_aliases join in fetch_tags_a.sql
        // (the alias name the user typed); NULL for direct matches.
        pub antecedent_name: Option<String>,
    }
}
mod db {
use deadpool_postgres::Client;
use tokio_pg_mapper::FromTokioPostgresRow;
use crate::models::Tag;
fn escape_like(stuff: &String) -> String {
stuff.replace("%", "\\%").replace("_", "\\_").replace("*", "%").replace("\\*", "*")
}
pub async fn get_tags(client: &Client, tag_prefix: &String) -> Result<Vec<Tag>, tokio_postgres::Error> {
let escape_prefix = escape_like(&(tag_prefix.to_owned() + "*"));
let _stmt = "set statement_timeout = 3000";
let stmt = client.prepare(&_stmt).await?;
client.execute(&stmt, &[]).await?;
let _stmt = include_str!("../sql/fetch_tags_a.sql");
let stmt = client.prepare(&_stmt).await?;
let rows = client.query(&stmt, &[&escape_prefix])
.await?.iter()
.map(|row| Tag::from_row_ref(row).unwrap())
.collect::<Vec<Tag>>();
if rows.len() > 0 {
return Ok(rows);
}
let _stmt = include_str!("../sql/fetch_tags_b.sql");
let stmt = client.prepare(&_stmt).await?;
let rows = client.query(&stmt, &[&tag_prefix])
.await?.iter()
.map(|row| Tag::from_row_ref(row).unwrap())
.collect::<Vec<Tag>>();
Ok(rows)
}
}
/// Shared application state handed to every request handler.
struct AutocompleteState {
    // Postgres connection pool (deadpool).
    pool: Pool,
    // Maps a validated tag prefix to its already-serialized JSON response.
    cache: Cache<String, String>,
}
/// Errors returned to HTTP clients; rendered by the `ResponseError` impl.
#[derive(Debug, Display, Error, From)]
enum AutocompleteError {
    // 400 — the search term failed validation in validate_transform_tag.
    #[display(fmt = "bad request")]
    BadRequest,
    // 500 — pool checkout or database query failed.
    #[display(fmt = "internal error")]
    ServerError,
}
impl error::ResponseError for AutocompleteError {
fn error_response(&self) -> HttpResponse {
match *self {
AutocompleteError::BadRequest => HttpResponseBuilder::new(self.status_code())
.set_header(header::CONTENT_TYPE, "application/json; charset=utf-8")
.set_header(header::CACHE_CONTROL, "private; max-age=0").body("{\"error\":\"bad request\"}"),
AutocompleteError::ServerError => HttpResponseBuilder::new(self.status_code())
.set_header(header::CONTENT_TYPE, "application/json; charset=utf-8")
.set_header(header::CACHE_CONTROL, "private; max-age=0").body("{\"error\":\"internal error\"}")
}
}
fn status_code(&self) -> StatusCode {
match *self {
AutocompleteError::BadRequest => StatusCode::BAD_REQUEST,
AutocompleteError::ServerError => StatusCode::INTERNAL_SERVER_ERROR
}
}
}
/// Normalize a raw search term into a lowercase, wildcard- and whitespace-free
/// tag prefix, or reject it with `BadRequest`.
///
/// The raw byte-length bounds run first as a cheap guard; the minimum length is
/// then re-checked on the normalized result, because input like `"a *"` passes
/// the raw check but strips down to fewer than 3 characters — previously such
/// input (e.g. `"  *  "`) could even become an empty prefix, turning the LIKE
/// pattern into a match-everything `%` query.
fn validate_transform_tag(tag: &str) -> Result<String, AutocompleteError> {
    use unicode_normalization::UnicodeNormalization;

    // Raw byte-length bounds (pre-normalization) — cheap DoS guard.
    if tag.len() > 100 || tag.len() < 3 {
        return Err(AutocompleteError::BadRequest);
    }

    // NFC-normalize, lowercase, strip LIKE wildcards and all whitespace.
    let tag_str: String = tag
        .nfc()
        .collect::<String>()
        .to_lowercase()
        .chars()
        .filter(|c| !c.is_whitespace() && *c != '*' && *c != '%')
        .collect();

    // Re-check the effective length after stripping.
    if tag_str.chars().count() < 3 {
        return Err(AutocompleteError::BadRequest);
    }
    Ok(tag_str)
}
/// Query-string parameters: `?search[name_matches]=<prefix>`.
#[derive(Deserialize)]
struct Req {
    #[serde(rename(deserialize = "search[name_matches]"))]
    tag_prefix: String
}
#[get("/")]
async fn autocomplete(data: web::Data<AutocompleteState>, req: web::Query<Req>) -> Result<HttpResponse, AutocompleteError> {
let prefix: String = validate_transform_tag(req.tag_prefix.as_str())?;
let cached = data.cache.get(&prefix);
return if cached.is_some() {
Ok(HttpResponse::Ok()
.set_header(header::CONTENT_TYPE, "application/json; charset=utf-8")
.set_header(header::CACHE_CONTROL, "public; max-age=604800")
.body(cached.unwrap()))
} else {
let client = match data.pool.get().await {
Ok(x) => x,
Err(x) => {
error!("{}", x.to_string());
return Err(AutocompleteError::ServerError);
}
};
let results = match db::get_tags(&client, &prefix).await {
Ok(x) => x,
Err(x) => {
error!("{}", x.to_string());
return Err(AutocompleteError::ServerError);
}
};
let serialized = serde_json::to_string(&results).unwrap_or_else(|_| "[]".to_string());
let serialized_copy = serialized.clone();
data.cache.insert(prefix, serialized).await;
Ok(HttpResponse::Ok()
.set_header(header::CONTENT_TYPE, "application/json; charset=utf-8")
.set_header(header::CACHE_CONTROL, "public, max-age=604800")
.body(serialized_copy))
};
}
/// Entry point: load configuration, build the pool and response cache, and run
/// the HTTP server until shutdown.
#[actix_web::main]
async fn main() -> std::io::Result<()> {
    use actix_web::{App, HttpServer};
    use dotenv::dotenv;
    use moka::future::CacheBuilder;
    use std::time::Duration;
    use tokio_postgres::NoTls;

    // Best-effort: a missing .env just falls back to the real environment.
    dotenv().ok();
    env_logger::init();

    // Startup config/pool failures are fatal; `expect` gives a useful message
    // instead of the original bare `unwrap()` panic.
    let config = crate::config::Config::from_env()
        .expect("invalid or missing environment configuration");
    let pool = config.pg.create_pool(NoTls)
        .expect("failed to create Postgres connection pool");

    // Up to 15k cached responses, each kept for at most 6 hours.
    let cache = CacheBuilder::new(15_000)
        .time_to_live(Duration::from_secs(6 * 60 * 60))
        .build();

    HttpServer::new(move || {
        App::new()
            .data(AutocompleteState {
                pool: pool.clone(),
                cache: cache.clone(),
            })
            .service(autocomplete)
    })
    // `config` is not captured by the closure, so no clone is needed here.
    .bind(config.server_addr.as_str())?
    .run()
    .await
}