mirror of
https://github.com/e621ng/autocompleted.git
synced 2025-03-04 03:03:02 -05:00
Initial import
This commit is contained in:
commit
d2cf1f0ce5
7
.env.sample
Normal file
7
.env.sample
Normal file
@ -0,0 +1,7 @@
|
||||
SERVER_ADDR=
|
||||
PG.USER=
|
||||
PG.PASSWORD=
|
||||
PG.HOST=
|
||||
PG.PORT=
|
||||
PG.DBNAME=
|
||||
PG.POOL.MAX_SIZE=
|
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
.env
|
||||
/target
|
2667
Cargo.lock
generated
Normal file
2667
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
22
Cargo.toml
Normal file
22
Cargo.toml
Normal file
@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "autocompleted"
|
||||
version = "0.1.0"
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
actix-web = "3"
|
||||
moka = { version = "0.6.1", features = ["future"] }
|
||||
config = "0.10.1"
|
||||
deadpool-postgres = "0.5.0"
|
||||
derive_more = "0.99.2"
|
||||
dotenv = "0.15.0"
|
||||
serde = { version = "1.0.104", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tokio-pg-mapper = "0.1"
|
||||
tokio-pg-mapper-derive = "0.1"
|
||||
tokio-postgres = "0.5.1"
|
||||
unicode-normalization = "0.1.19"
|
||||
log = "0.4"
|
||||
env_logger = "0.9.0"
|
8
sql/fetch_tags_a.sql
Normal file
8
sql/fetch_tags_a.sql
Normal file
@ -0,0 +1,8 @@
|
||||
SELECT DISTINCT ON (name, post_count) * FROM (
|
||||
(SELECT tags.id, tags.name, tags.post_count, tags.category, null AS antecedent_name FROM
|
||||
"tags" WHERE (tags.name LIKE $1 ESCAPE E'\\') AND (post_count > 0) ORDER BY post_count desc LIMIT 10)
|
||||
UNION ALL
|
||||
(SELECT tags.id, tags.name, tags.post_count, tags.category, tag_aliases.antecedent_name
|
||||
FROM "tag_aliases"
|
||||
INNER JOIN tags ON tags.name = tag_aliases.consequent_name
|
||||
WHERE (tag_aliases.antecedent_name LIKE $1 ESCAPE E'\\') AND "tag_aliases"."status" IN ('active', 'processing', 'queued') AND (tags.name NOT LIKE $1 ESCAPE E'\\') AND (tag_aliases.post_count > 0) ORDER BY tag_aliases.post_count desc LIMIT 20)) AS unioned_query ORDER BY post_count desc LIMIT 10
|
1
sql/fetch_tags_b.sql
Normal file
1
sql/fetch_tags_b.sql
Normal file
@ -0,0 +1 @@
|
||||
SELECT tags.id, tags.name, tags.post_count, tags.category, null AS antecedent_name FROM "tags" WHERE (tags.name % $1) AND (tags.post_count > 0) ORDER BY trunc(3 * similarity(name, $1)) DESC, post_count DESC, name DESC LIMIT 10
|
186
src/main.rs
Normal file
186
src/main.rs
Normal file
@ -0,0 +1,186 @@
|
||||
use actix_web::{dev::HttpResponseBuilder, error, get, http::header, http::StatusCode, HttpResponse, web};
|
||||
use deadpool_postgres::Pool;
|
||||
use derive_more::{Display, Error, From};
|
||||
use moka::future::Cache;
|
||||
use log::error;
|
||||
use serde::Deserialize;
|
||||
|
||||
mod config {
|
||||
pub use ::config::ConfigError;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Config {
|
||||
pub server_addr: String,
|
||||
pub pg: deadpool_postgres::Config,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn from_env() -> Result<Self, ConfigError> {
|
||||
let mut cfg = ::config::Config::new();
|
||||
cfg.merge(::config::Environment::new())?;
|
||||
cfg.try_into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod models {
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio_pg_mapper_derive::PostgresMapper;
|
||||
|
||||
#[derive(Deserialize, PostgresMapper, Serialize)]
|
||||
#[pg_mapper(table = "tags")] // singular 'user' is a keyword..
|
||||
pub struct Tag {
|
||||
pub id: i32,
|
||||
pub name: String,
|
||||
pub post_count: i32,
|
||||
pub category: i16,
|
||||
pub antecedent_name: Option<String>,
|
||||
}
|
||||
}
|
||||
|
||||
mod db {
|
||||
use deadpool_postgres::Client;
|
||||
use tokio_pg_mapper::FromTokioPostgresRow;
|
||||
|
||||
use crate::models::Tag;
|
||||
|
||||
fn escape_like(stuff: &String) -> String {
|
||||
stuff.replace("%", "\\%").replace("_", "\\_").replace("*", "%").replace("\\*", "*")
|
||||
}
|
||||
|
||||
pub async fn get_tags(client: &Client, tag_prefix: &String) -> Result<Vec<Tag>, tokio_postgres::Error> {
|
||||
let escape_prefix = escape_like(&(tag_prefix.to_owned() + "*"));
|
||||
let _stmt = "set statement_timeout = 3000";
|
||||
let stmt = client.prepare(&_stmt).await?;
|
||||
client.execute(&stmt, &[]).await?;
|
||||
let _stmt = include_str!("../sql/fetch_tags_a.sql");
|
||||
let stmt = client.prepare(&_stmt).await?;
|
||||
let rows = client.query(&stmt, &[&escape_prefix])
|
||||
.await?.iter()
|
||||
.map(|row| Tag::from_row_ref(row).unwrap())
|
||||
.collect::<Vec<Tag>>();
|
||||
if rows.len() > 0 {
|
||||
return Ok(rows);
|
||||
}
|
||||
let _stmt = include_str!("../sql/fetch_tags_b.sql");
|
||||
let stmt = client.prepare(&_stmt).await?;
|
||||
let rows = client.query(&stmt, &[&tag_prefix])
|
||||
.await?.iter()
|
||||
.map(|row| Tag::from_row_ref(row).unwrap())
|
||||
.collect::<Vec<Tag>>();
|
||||
Ok(rows)
|
||||
}
|
||||
}
|
||||
|
||||
struct AutocompleteState {
|
||||
pool: Pool,
|
||||
cache: Cache<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Display, Error, From)]
|
||||
enum AutocompleteError {
|
||||
#[display(fmt = "bad request")]
|
||||
BadRequest,
|
||||
#[display(fmt = "internal error")]
|
||||
ServerError,
|
||||
}
|
||||
|
||||
impl error::ResponseError for AutocompleteError {
|
||||
fn error_response(&self) -> HttpResponse {
|
||||
match *self {
|
||||
AutocompleteError::BadRequest => HttpResponseBuilder::new(self.status_code())
|
||||
.set_header(header::CONTENT_TYPE, "application/json; charset=utf-8")
|
||||
.set_header(header::CACHE_CONTROL, "private; max-age=0").body("{\"error\":\"bad request\"}"),
|
||||
AutocompleteError::ServerError => HttpResponseBuilder::new(self.status_code())
|
||||
.set_header(header::CONTENT_TYPE, "application/json; charset=utf-8")
|
||||
.set_header(header::CACHE_CONTROL, "private; max-age=0").body("{\"error\":\"internal error\"}")
|
||||
}
|
||||
}
|
||||
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match *self {
|
||||
AutocompleteError::BadRequest => StatusCode::BAD_REQUEST,
|
||||
AutocompleteError::ServerError => StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_transform_tag(tag: &str) -> Result<String, AutocompleteError> {
|
||||
use unicode_normalization::UnicodeNormalization;
|
||||
if tag.len() > 100 {
|
||||
return Err(AutocompleteError::BadRequest);
|
||||
}
|
||||
if tag.len() < 3 {
|
||||
return Err(AutocompleteError::BadRequest);
|
||||
}
|
||||
let tag_str = tag.nfc().collect::<String>().to_lowercase().replace("*", "").replace("%", "").chars().filter(|x| !x.is_whitespace()).collect();
|
||||
Ok(tag_str)
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Req {
|
||||
#[serde(rename(deserialize = "search[name_matches]"))]
|
||||
tag_prefix: String
|
||||
}
|
||||
|
||||
#[get("/")]
|
||||
async fn autocomplete(data: web::Data<AutocompleteState>, req: web::Query<Req>) -> Result<HttpResponse, AutocompleteError> {
|
||||
let prefix: String = validate_transform_tag(req.tag_prefix.as_str())?;
|
||||
let cached = data.cache.get(&prefix);
|
||||
return if cached.is_some() {
|
||||
Ok(HttpResponse::Ok()
|
||||
.set_header(header::CONTENT_TYPE, "application/json; charset=utf-8")
|
||||
.set_header(header::CACHE_CONTROL, "public; max-age=604800")
|
||||
.body(cached.unwrap()))
|
||||
} else {
|
||||
let client = match data.pool.get().await {
|
||||
Ok(x) => x,
|
||||
Err(x) => {
|
||||
error!("{}", x.to_string());
|
||||
return Err(AutocompleteError::ServerError);
|
||||
}
|
||||
};
|
||||
let results = match db::get_tags(&client, &prefix).await {
|
||||
Ok(x) => x,
|
||||
Err(x) => {
|
||||
error!("{}", x.to_string());
|
||||
return Err(AutocompleteError::ServerError);
|
||||
}
|
||||
};
|
||||
let serialized = serde_json::to_string(&results).unwrap_or_else(|_| "[]".to_string());
|
||||
let serialized_copy = serialized.clone();
|
||||
data.cache.insert(prefix, serialized).await;
|
||||
Ok(HttpResponse::Ok()
|
||||
.set_header(header::CONTENT_TYPE, "application/json; charset=utf-8")
|
||||
.set_header(header::CACHE_CONTROL, "public, max-age=604800")
|
||||
.body(serialized_copy))
|
||||
};
|
||||
}
|
||||
|
||||
#[actix_web::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
use actix_web::{App, HttpServer};
|
||||
use dotenv::dotenv;
|
||||
use tokio_postgres::NoTls;
|
||||
use moka::future::CacheBuilder;
|
||||
use std::time::Duration;
|
||||
dotenv().ok();
|
||||
env_logger::init();
|
||||
|
||||
let config = crate::config::Config::from_env().unwrap();
|
||||
let pool = config.pg.create_pool(NoTls).unwrap();
|
||||
let cache = CacheBuilder::new(15_000)
|
||||
.time_to_live(Duration::from_secs(6 * 60 * 60))
|
||||
.build();
|
||||
|
||||
HttpServer::new(move || {
|
||||
App::new()
|
||||
.data(AutocompleteState {
|
||||
pool: pool.clone(),
|
||||
cache: cache.clone(),
|
||||
})
|
||||
.service(autocomplete)
|
||||
}).bind(config.server_addr.clone())?
|
||||
.run().await
|
||||
}
|
Loading…
Reference in New Issue
Block a user