diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/pleroma/user.ex | 1 | ||||
-rw-r--r-- | lib/pleroma/user/search.ex | 101 |
2 files changed, 94 insertions, 8 deletions
diff --git a/lib/pleroma/user.ex b/lib/pleroma/user.ex index c5c74d132..f0193207c 100644 --- a/lib/pleroma/user.ex +++ b/lib/pleroma/user.ex @@ -92,6 +92,7 @@ defmodule Pleroma.User do field(:local, :boolean, default: true) field(:follower_address, :string) field(:following_address, :string) + field(:levenshtein_distance, :integer, virtual: true) field(:search_rank, :float, virtual: true) field(:search_type, :integer, virtual: true) field(:tags, {:array, :string}, default: []) diff --git a/lib/pleroma/user/search.ex b/lib/pleroma/user/search.ex index cec59c372..3878e81bf 100644 --- a/lib/pleroma/user/search.ex +++ b/lib/pleroma/user/search.ex @@ -8,6 +8,8 @@ defmodule Pleroma.User.Search do import Ecto.Query @limit 20 + @levenshtein_max_query_length 5 + @search_rank_threshold 0 def search(query_string, opts \\ []) do resolve = Keyword.get(opts, :resolve, false) @@ -31,7 +33,10 @@ defmodule Pleroma.User.Search do defp format_query(query_string) do # Strip the beginning @ off if there is a query - query_string = String.trim_leading(query_string, "@") + query_string = + query_string + |> String.trim() + |> String.trim_leading("@") with [name, domain] <- String.split(query_string, "@") do encoded_domain = @@ -47,15 +52,33 @@ defmodule Pleroma.User.Search do end end + defp levenshtein_applicable?(query_string) do + String.length(query_string) <= @levenshtein_max_query_length + end + defp search_query(query_string, for_user, following) do - for_user - |> base_query(following) - |> filter_blocked_user(for_user) - |> filter_invisible_users() - |> filter_blocked_domains(for_user) - |> fts_search(query_string) - |> trigram_rank(query_string) + query = + for_user + |> base_query(following) + |> filter_blocked_user(for_user) + |> filter_invisible_users() + |> filter_blocked_domains(for_user) + + query = + if levenshtein_applicable?(query_string) do + query + |> levenshtein_distance(query_string) + |> fts_levenshtein_search(query_string) + |> trigram_levenshtein_rank(query_string) + else + query + |> fts_search(query_string) + |> trigram_rank(query_string) + end + + query |> boost_search_rank(for_user) + |> filter_by_search_rank() |> subquery() |> order_by(desc: :search_rank) |> maybe_restrict_local(for_user) @@ -78,6 +101,25 @@ defmodule Pleroma.User.Search do ) end + defp fts_levenshtein_search(query, query_string) do + tsquery = to_tsquery(query_string) + + from( + u in subquery(query), + where: + fragment( + """ + ? <= 2 OR + (to_tsvector('simple', ?) || to_tsvector('simple', ?)) @@ to_tsquery('simple', ?) + """, + u.levenshtein_distance, + u.name, + u.nickname, + ^tsquery + ) + ) + end + defp to_tsquery(query_string) do String.trim_trailing(query_string, "@" <> local_domain()) |> String.replace(~r/[!-\/|@|[-`|{-~|:-?]+/, " ") @@ -87,6 +129,45 @@ defmodule Pleroma.User.Search do |> Enum.join(" | ") end + # Trigram-based rank with bonus for close Levenshtein distance b/w query and nickname + defp trigram_levenshtein_rank(query, query_string) do + from( + u in subquery(query), + select_merge: %{ + search_rank: + fragment( + "similarity(?, trim(? || ' ' || coalesce(?, ''))) + \ + (CASE WHEN ? = 0 THEN 1.0 \ + WHEN ? = 1 AND length(?) > 1 THEN 0.5 + WHEN ? = 2 AND length(?) > 3 THEN 0.2 + ELSE 0 END)", + ^query_string, + u.nickname, + u.name, + u.levenshtein_distance, + u.levenshtein_distance, + ^query_string, + u.levenshtein_distance, + ^query_string + ) + } + ) + end + + defp levenshtein_distance(query, query_string) do + from( + u in query, + select_merge: %{ + levenshtein_distance: + fragment( + "levenshtein(?, regexp_replace(?, '@.+', ''))", + ^query_string, + u.nickname + ) + } + ) + end + defp trigram_rank(query, query_string) do from( u in query, @@ -185,4 +266,8 @@ defmodule Pleroma.User.Search do end defp boost_search_rank(query, _for_user), do: query + + defp filter_by_search_rank(query) do + from(u in subquery(query), where: u.search_rank > @search_rank_threshold) + end end |