diff --git a/internal/infrastructure/database/repository/recipe_repository.go b/internal/infrastructure/database/repository/recipe_repository.go index e346a80..7a6f46d 100644 --- a/internal/infrastructure/database/repository/recipe_repository.go +++ b/internal/infrastructure/database/repository/recipe_repository.go @@ -8,6 +8,7 @@ import ( "strings" "time" + sq "github.com/Masterminds/squirrel" domain "github.com/haydenhargreaves/Potion/internal/domain/recipe" "github.com/jmoiron/sqlx" "github.com/lib/pq" @@ -339,201 +340,364 @@ func isBitActive(bits, pos int) bool { // 12/28/25: This function has changed, now longer returns the recipes, but their IDs for fetching // elsewhere. // +// 2/3/26: Refactored this large function to use Squirrel for simpler generation. +// // This function will only return recipes that are not deleted. Any recipes marked deleted will be ignored // and the standard "not-found" error will be returned. func (r *RecipeRepository) SearchRecipes(filters domain.SearchFilters, userId *int, favorites bool) ([]int, error) { - // Compute meals type filters (there are 7 bits) - var mealConditions []string + psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar) + + query := psql.Select("r.id").From("recipes r") + + // Only select fields where the recipe ID can be found in the favorites table (mapped to user ID) + if favorites && userId != nil { + query = query. + Join("favorites f ON f.recipeId = r.id"). + Where(sq.Eq{"f.userid": *userId}) + } + + // Compute and add meal type filters (7 bit options) + var mealCategories []string for i := range 7 { if isBitActive(filters.MealType, i) { - mealConditions = append(mealConditions, fmt.Sprintf("category = '%s'", domain.ParseMeal(i))) + mealCategories = append(mealCategories, string(domain.ParseMeal(i))) } } - // Compute time filters (there are 5 bits) - var timeConditions []string + if len(mealCategories) > 0 { + query = query.Where(sq.Eq{"category": mealCategories}) + } + + // Compute and add time filters (5 bit options) + var timeOr sq.Or for i := range 5 { - var cond string if isBitActive(filters.Time, i) { switch i { case 0: - cond = "(duration->>'total')::int < 15" + timeOr = append(timeOr, sq.Lt{"(duration->>'total')::int": 15}) case 1: - cond = "(duration->>'total')::int BETWEEN 15 AND 30" + timeOr = append(timeOr, sq.Expr("(duration->>'total')::int BETWEEN 15 AND 30")) case 2: - cond = "(duration->>'total')::int BETWEEN 30 AND 60" + timeOr = append(timeOr, sq.Expr("(duration->>'total')::int BETWEEN 30 AND 60")) case 3: - cond = "(duration->>'total')::int BETWEEN 60 AND 120" + timeOr = append(timeOr, sq.Expr("(duration->>'total')::int BETWEEN 60 AND 120")) case 4: - cond = "(duration->>'total')::int > 120" + timeOr = append(timeOr, sq.Gt{"(duration->>'total')::int": 120}) } - timeConditions = append(timeConditions, cond) } } - // Compute difficulty filters (there are 5 bits) - var difficultyConditions []string + if len(timeOr) > 0 { + query = query.Where(timeOr) + } + + // Compute and add difficulty filters (5 bit options) + var difficulties []int for i := range 5 { if isBitActive(filters.Difficulty, i) { - cond := fmt.Sprintf("difficulty = '%d'", i+1) - difficultyConditions = append(difficultyConditions, cond) + difficulties = append(difficulties, i+1) } } - // Compute serving size filters (there are 5 bits) - var servingConditions []string + if len(difficulties) > 0 { + query = query.Where(sq.Eq{"difficulty": difficulties}) + } + + // Compute and add serving size filters (5 bit options) + var servingOr sq.Or for i := range 5 { - var cond string if isBitActive(filters.ServingSize, i) { switch i { case 0: - cond = "serves BETWEEN 1 AND 2" + servingOr = append(servingOr, sq.Expr("serves BETWEEN 1 AND 2")) case 1: - cond = "serves BETWEEN 2 AND 4" + servingOr = append(servingOr, sq.Expr("serves BETWEEN 2 AND 4")) case 2: - cond = "serves BETWEEN 4 AND 6" + servingOr = append(servingOr, sq.Expr("serves BETWEEN 4 AND 6")) case 3: - cond = "serves BETWEEN 6 AND 8" + servingOr = append(servingOr, sq.Expr("serves BETWEEN 6 AND 8")) case 4: - cond = "serves > 8" + servingOr = append(servingOr, sq.Gt{"serves": 8}) } - servingConditions = append(servingConditions, cond) } } - // Merge condition strings - mealString := fmt.Sprintf("(%s)", strings.Join(mealConditions, " OR ")) - timeString := fmt.Sprintf("(%s)", strings.Join(timeConditions, " OR ")) - difficultyString := fmt.Sprintf("(%s)", strings.Join(difficultyConditions, " OR ")) - servingString := fmt.Sprintf("(%s)", strings.Join(servingConditions, " OR ")) - - // Combine condition strings - var conditions []string - if len(mealConditions) > 0 { - conditions = append(conditions, mealString) - } - if len(timeConditions) > 0 { - conditions = append(conditions, timeString) - } - if len(difficultyConditions) > 0 { - conditions = append(conditions, difficultyString) - } - if len(servingConditions) > 0 { - conditions = append(conditions, servingString) + if len(servingOr) > 0 { + query = query.Where(servingOr) } - // Define columns to select - columns := []string{ - "r.id", - } - - // Create search vector query with SAFE parameterization - var orderBy string = "" - var searchQuery string = "" - + // Handle search with full-text search and ILIKE fallback if filters.Search != "" { spl := strings.Split(filters.Search, " ") var cleaned []string - // Use a string replacer for safety + // Sanitize search terms replacer := strings.NewReplacer( "'", "", "-", "", "&", "", "|", "", "!", "", - ":", "", // Remove colons to prevent tsquery syntax injection + ":", "", "(", "", ")", "", ) - for i := range len(spl) { - q := strings.TrimSpace(replacer.Replace(spl[i])) - if q != "" { - // Add :* suffix for prefix matching - cleaned = append(cleaned, q+":*") - } - } - - // Join with OR operator for full-text search - vector_query := strings.Join(cleaned, " | ") - searchQuery = vector_query - - // Full-text search with prefix matching - searchCondition := fmt.Sprintf("r.search_vector @@ to_tsquery('english', '%s')", vector_query) - - // Add fallback ILIKE for true substring matching - // This catches cases where "pan" is inside "pancake" but not at word boundaries - var ilikeConditions []string for _, term := range spl { - cleanTerm := strings.TrimSpace(replacer.Replace(term)) - if cleanTerm != "" { - ilikeConditions = append(ilikeConditions, fmt.Sprintf("(r.title ILIKE '%%%s%%' OR r.description ILIKE '%%%s%%')", cleanTerm, cleanTerm)) + q := strings.TrimSpace(replacer.Replace(term)) + if q != "" { + cleaned = append(cleaned, q+":*") // Add prefix matching } } - if len(ilikeConditions) > 0 { - searchCondition = fmt.Sprintf("(%s OR %s)", searchCondition, strings.Join(ilikeConditions, " OR ")) + if len(cleaned) > 0 { + vectorQuery := strings.Join(cleaned, " | ") + + // Build search condition as raw SQL expression + // We'll use sq.Expr for the entire OR clause + var searchConditions []string + var searchArgs []interface{} + + // Full-text search + searchConditions = append(searchConditions, "r.search_vector @@ to_tsquery('english', ?)") + searchArgs = append(searchArgs, vectorQuery) + + // ILIKE fallback for substring matching + for _, term := range spl { + cleanTerm := strings.TrimSpace(replacer.Replace(term)) + if cleanTerm != "" { + searchConditions = append(searchConditions, "r.title ILIKE ?") + searchArgs = append(searchArgs, "%"+cleanTerm+"%") + + searchConditions = append(searchConditions, "r.description ILIKE ?") + searchArgs = append(searchArgs, "%"+cleanTerm+"%") + } + } + + // Combine all conditions with OR + searchExpr := fmt.Sprintf("(%s)", strings.Join(searchConditions, " OR ")) + query = query.Where(sq.Expr(searchExpr, searchArgs...)) + + // Add ordering for search results + query = query. + OrderBy(fmt.Sprintf("CASE WHEN r.search_vector @@ to_tsquery('english', '%s') THEN 1 ELSE 2 END", vectorQuery)). + OrderBy(fmt.Sprintf("ts_rank(r.search_vector, to_tsquery('english', '%s')) DESC", vectorQuery)). + OrderBy(fmt.Sprintf("ts_rank_cd(r.search_vector, to_tsquery('english', '%s')) DESC", vectorQuery)) } - - conditions = append(conditions, searchCondition) - - // Ranking with preference for full-text matches - orderBy = fmt.Sprintf(` - ORDER BY - CASE - WHEN r.search_vector @@ to_tsquery('english', '%s') THEN 1 - ELSE 2 - END, - ts_rank(r.search_vector, to_tsquery('english', '%s')) DESC, - ts_rank_cd(r.search_vector, to_tsquery('english', '%s')) DESC - `, searchQuery, searchQuery, searchQuery) } - // Generate the query - var query string - if favorites && userId != nil { - query = fmt.Sprintf( - "SELECT %s FROM recipes r JOIN favorites f ON f.recipeId = r.id", - strings.Join(columns, ","), - ) - conditions = append(conditions, fmt.Sprintf("f.userid = %d", *userId)) - } else { - query = fmt.Sprintf("SELECT %s FROM recipes r", strings.Join(columns, ",")) - } + // Exclude deleted recipes + query = query.Where(sq.Eq{"deleted": false}) - // Convert and append conditions if provided - conditions = append(conditions, "deleted = false") - if len(conditions) > 0 { - conditionsString := fmt.Sprintf("WHERE %s", strings.Join(conditions, " AND ")) - query = fmt.Sprintf("%s %s", query, conditionsString) - } - - // Append sorting order if exists - if len(orderBy) > 0 { - query = fmt.Sprintf("%s %s", query, orderBy) - } - - // Finish it off with a semicolon! - query += ";" - - // Execute the query - rows, err := r.db.Query(query) + sql, args, err := query.ToSql() if err != nil { - return []int{}, fmt.Errorf("failed to query recipes: %w", err) + return nil, fmt.Errorf("[ERROR] Failed to build query: %w", err) } - defer rows.Close() + fmt.Println(sql) + fmt.Println(args) + + // Execute query using SQLX var ids []int - for rows.Next() { - var id int - if err := rows.Scan(&id); err != nil { - return []int{}, fmt.Errorf("failed to extract ID: %s\n", err.Error()) - } - ids = append(ids, id) + if err = r.db.Select(&ids, sql, args...); err != nil { + return nil, fmt.Errorf("[ERROR] Failed to query recipes: %w", err) } return ids, nil + + // LEGACY CODE + // Compute meals type filters (there are 7 bits) + // var mealConditions []string + // for i := range 7 { + // if isBitActive(filters.MealType, i) { + // mealConditions = append(mealConditions, fmt.Sprintf("category = '%s'", domain.ParseMeal(i))) + // } + // } + // + // // Compute time filters (there are 5 bits) + // var timeConditions []string + // for i := range 5 { + // var cond string + // if isBitActive(filters.Time, i) { + // switch i { + // case 0: + // cond = "(duration->>'total')::int < 15" + // case 1: + // cond = "(duration->>'total')::int BETWEEN 15 AND 30" + // case 2: + // cond = "(duration->>'total')::int BETWEEN 30 AND 60" + // case 3: + // cond = "(duration->>'total')::int BETWEEN 60 AND 120" + // case 4: + // cond = "(duration->>'total')::int > 120" + // } + // timeConditions = append(timeConditions, cond) + // } + // } + // + // // Compute difficulty filters (there are 5 bits) + // var difficultyConditions []string + // for i := range 5 { + // if isBitActive(filters.Difficulty, i) { + // cond := fmt.Sprintf("difficulty = '%d'", i+1) + // difficultyConditions = append(difficultyConditions, cond) + // } + // } + // + // // Compute serving size filters (there are 5 bits) + // var servingConditions []string + // for i := range 5 { + // var cond string + // if isBitActive(filters.ServingSize, i) { + // switch i { + // case 0: + // cond = "serves BETWEEN 1 AND 2" + // case 1: + // cond = "serves BETWEEN 2 AND 4" + // case 2: + // cond = "serves BETWEEN 4 AND 6" + // case 3: + // cond = "serves BETWEEN 6 AND 8" + // case 4: + // cond = "serves > 8" + // } + // servingConditions = append(servingConditions, cond) + // } + // } + // + // // Merge condition strings + // mealString := fmt.Sprintf("(%s)", strings.Join(mealConditions, " OR ")) + // timeString := fmt.Sprintf("(%s)", strings.Join(timeConditions, " OR ")) + // difficultyString := fmt.Sprintf("(%s)", strings.Join(difficultyConditions, " OR ")) + // servingString := fmt.Sprintf("(%s)", strings.Join(servingConditions, " OR ")) + // + // // Combine condition strings + // var conditions []string + // if len(mealConditions) > 0 { + // conditions = append(conditions, mealString) + // } + // if len(timeConditions) > 0 { + // conditions = append(conditions, timeString) + // } + // if len(difficultyConditions) > 0 { + // conditions = append(conditions, difficultyString) + // } + // if len(servingConditions) > 0 { + // conditions = append(conditions, servingString) + // } + // + // // Define columns to select + // columns := []string{ + // "r.id", + // } + // + // // Create search vector query with SAFE parameterization + // var orderBy string = "" + // var searchQuery string = "" + // + // if filters.Search != "" { + // spl := strings.Split(filters.Search, " ") + // var cleaned []string + // + // // Use a string replacer for safety + // replacer := strings.NewReplacer( + // "'", "", + // "-", "", + // "&", "", + // "|", "", + // "!", "", + // ":", "", // Remove colons to prevent tsquery syntax injection + // "(", "", + // ")", "", + // ) + // + // for i := range len(spl) { + // q := strings.TrimSpace(replacer.Replace(spl[i])) + // if q != "" { + // // Add :* suffix for prefix matching + // cleaned = append(cleaned, q+":*") + // } + // } + // + // // Join with OR operator for full-text search + // vector_query := strings.Join(cleaned, " | ") + // searchQuery = vector_query + // + // // Full-text search with prefix matching + // searchCondition := fmt.Sprintf("r.search_vector @@ to_tsquery('english', '%s')", vector_query) + // + // // Add fallback ILIKE for true substring matching + // // This catches cases where "pan" is inside "pancake" but not at word boundaries + // var ilikeConditions []string + // for _, term := range spl { + // cleanTerm := strings.TrimSpace(replacer.Replace(term)) + // if cleanTerm != "" { + // ilikeConditions = append(ilikeConditions, fmt.Sprintf("(r.title ILIKE '%%%s%%' OR r.description ILIKE '%%%s%%')", cleanTerm, cleanTerm)) + // } + // } + // + // if len(ilikeConditions) > 0 { + // searchCondition = fmt.Sprintf("(%s OR %s)", searchCondition, strings.Join(ilikeConditions, " OR ")) + // } + // + // conditions = append(conditions, searchCondition) + // + // // Ranking with preference for full-text matches + // orderBy = fmt.Sprintf(` + // ORDER BY + // CASE + // WHEN r.search_vector @@ to_tsquery('english', '%s') THEN 1 + // ELSE 2 + // END, + // ts_rank(r.search_vector, to_tsquery('english', '%s')) DESC, + // ts_rank_cd(r.search_vector, to_tsquery('english', '%s')) DESC + // `, searchQuery, searchQuery, searchQuery) + // } + // + // // Generate the query + // var query string + // if favorites && userId != nil { + // query = fmt.Sprintf( + // "SELECT %s FROM recipes r JOIN favorites f ON f.recipeId = r.id", + // strings.Join(columns, ","), + // ) + // conditions = append(conditions, fmt.Sprintf("f.userid = %d", *userId)) + // } else { + // query = fmt.Sprintf("SELECT %s FROM recipes r", strings.Join(columns, ",")) + // } + // + // // Convert and append conditions if provided + // conditions = append(conditions, "deleted = false") + // if len(conditions) > 0 { + // conditionsString := fmt.Sprintf("WHERE %s", strings.Join(conditions, " AND ")) + // query = fmt.Sprintf("%s %s", query, conditionsString) + // } + // + // // Append sorting order if exists + // if len(orderBy) > 0 { + // query = fmt.Sprintf("%s %s", query, orderBy) + // } + // + // // Finish it off with a semicolon! + // query += ";" + // + // // Execute the query + // rows, err := r.db.Query(query) + // if err != nil { + // return []int{}, fmt.Errorf("failed to query recipes: %w", err) + // } + // defer rows.Close() + // + // var ids []int + // for rows.Next() { + // var id int + // if err := rows.Scan(&id); err != nil { + // return []int{}, fmt.Errorf("failed to extract ID: %s\n", err.Error()) + // } + // ids = append(ids, id) + // } + // + // return ids, nil } // CreateRecipeTags accepts a list of tags (names) and a recipe (already created by the DB) and