diff --git a/internal/infrastructure/database/repository/recipe_repository.go b/internal/infrastructure/database/repository/recipe_repository.go index f229284..856acaf 100644 --- a/internal/infrastructure/database/repository/recipe_repository.go +++ b/internal/infrastructure/database/repository/recipe_repository.go @@ -179,7 +179,7 @@ func (r *RecipeRepository) EditRecipe(recipe *domain.Recipe, userId int) error { tx.Rollback() return err } - + if rows != 1 { tx.Rollback() return fmt.Errorf("[ERROR] Modified an unexpected number of rows. Expected 1, modified %d.", rows) @@ -421,52 +421,72 @@ func (r *RecipeRepository) SearchRecipes(filters domain.SearchFilters, userId *i conditions = append(conditions, servingString) } - // Define columns to select. More fields can be added if the full text search is required + // Define columns to select columns := []string{ "r.id", } - // TODO: Need to add these to the query - - // FROM ... JOIN favorites f ON f.recipeId = r.id - // WHERE ... AND f.userId = 3 - - // Create search vector query + // Create search vector query with SAFE parameterization var orderBy string = "" + var searchQuery string = "" + if filters.Search != "" { spl := strings.Split(filters.Search, " ") var cleaned []string - // Use a string replacer, each word in the query will be passed through this + // Use a string replacer for safety replacer := strings.NewReplacer( "'", "", "-", "", "&", "", "|", "", "!", "", + ":", "", // Remove colons to prevent tsquery syntax injection + "(", "", + ")", "", ) for i := range len(spl) { q := strings.TrimSpace(replacer.Replace(spl[i])) - if q != "" { - cleaned = append(cleaned, q) + // Add :* suffix for prefix matching + cleaned = append(cleaned, q+":*") } } + // Join with OR operator for full-text search vector_query := strings.Join(cleaned, " | ") + searchQuery = vector_query - conditions = append( - conditions, - fmt.Sprintf("r.search_vector @@ to_tsquery('english', '%s')", vector_query), - ) + // Full-text search with prefix matching + searchCondition := fmt.Sprintf("r.search_vector @@ to_tsquery('english', '%s')", vector_query) - template := ` + // Add fallback ILIKE for true substring matching + // This catches cases where "pan" is inside "pancake" but not at word boundaries + var ilikeConditions []string + for _, term := range spl { + cleanTerm := strings.TrimSpace(replacer.Replace(term)) + if cleanTerm != "" { + ilikeConditions = append(ilikeConditions, fmt.Sprintf("(r.title ILIKE '%%%s%%' OR r.description ILIKE '%%%s%%')", cleanTerm, cleanTerm)) + } + } + + if len(ilikeConditions) > 0 { + searchCondition = fmt.Sprintf("(%s OR %s)", searchCondition, strings.Join(ilikeConditions, " OR ")) + } + + conditions = append(conditions, searchCondition) + + // Ranking with preference for full-text matches + orderBy = fmt.Sprintf(` ORDER BY + CASE + WHEN r.search_vector @@ to_tsquery('english', '%s') THEN 1 + ELSE 2 + END, ts_rank(r.search_vector, to_tsquery('english', '%s')) DESC, ts_rank_cd(r.search_vector, to_tsquery('english', '%s')) DESC - ` - orderBy = fmt.Sprintf(template, vector_query, vector_query) + `, searchQuery, searchQuery, searchQuery) } // Generate the query @@ -476,8 +496,6 @@ func (r *RecipeRepository) SearchRecipes(filters domain.SearchFilters, userId *i "SELECT %s FROM recipes r JOIN favorites f ON f.recipeId = r.id", strings.Join(columns, ","), ) - - // Add new favorite condition to the conditions list conditions = append(conditions, fmt.Sprintf("f.userid = %d", *userId)) } else { query = fmt.Sprintf("SELECT %s FROM recipes r", strings.Join(columns, ",")) @@ -495,7 +513,7 @@ func (r *RecipeRepository) SearchRecipes(filters domain.SearchFilters, userId *i query = fmt.Sprintf("%s %s", query, orderBy) } - // Finish it off with a colon! + // Finish it off with a semicolon! query += ";" // Execute the query @@ -511,7 +529,6 @@ func (r *RecipeRepository) SearchRecipes(filters domain.SearchFilters, userId *i if err := rows.Scan(&id); err != nil { return []int{}, fmt.Errorf("failed to extract ID: %s\n", err.Error()) } - ids = append(ids, id) } @@ -586,17 +603,17 @@ func (r *RecipeRepository) UpdateRecipeTags(recipe domain.Recipe, tags []string) return err } defer tx.Rollback() // Rollback if we don't commit - + if recipe.Id <= 0 { return fmt.Errorf("[ERROR] Recipe must have a valid ID") } - + // Step 1: Delete all existing tag associations for this recipe deleteQuery := `DELETE FROM RecipeTags WHERE RecipeId = $1;` if _, err := tx.Exec(deleteQuery, recipe.Id); err != nil { return fmt.Errorf("[ERROR] Failed to delete existing recipe tags: %w", err) } - + // Step 2: Normalize the tag names (lower case with trimmed space) normalized := make(map[string]struct{}) // Use map to disallow duplicates for _, tag := range tags { @@ -605,7 +622,7 @@ func (r *RecipeRepository) UpdateRecipeTags(recipe domain.Recipe, tags []string) normalized[trimmed] = struct{}{} } } - + // If no tags provided, we're done (all tags removed) if len(normalized) == 0 { if err := tx.Commit(); err != nil { @@ -613,7 +630,7 @@ func (r *RecipeRepository) UpdateRecipeTags(recipe domain.Recipe, tags []string) } return nil } - + // Step 3: Insert the tags into the DB and return their IDs into the tag ID list var tagIds []int for tag := range normalized { @@ -629,7 +646,7 @@ func (r *RecipeRepository) UpdateRecipeTags(recipe domain.Recipe, tags []string) } tagIds = append(tagIds, tagId) } - + // Step 4: Insert the new tag associations // Use a single prepared statement for all inserts stmt, err := tx.Prepare("INSERT INTO RecipeTags (RecipeId, TagId) VALUES ($1, $2);") @@ -637,18 +654,18 @@ func (r *RecipeRepository) UpdateRecipeTags(recipe domain.Recipe, tags []string) return fmt.Errorf("[ERROR] Failed to create statement for recipe tag mapping: %w", err) } defer stmt.Close() - + for _, id := range tagIds { if _, err := stmt.Exec(recipe.Id, id); err != nil { return fmt.Errorf("[ERROR] Failed to insert tag-recipe mapping: %w", err) } } - + // Commit the transaction if err := tx.Commit(); err != nil { return err } - + return nil }