873 lines
25 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
sq "github.com/Masterminds/squirrel"
domain "github.com/haydenhargreaves/Potion/internal/domain/recipe"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
)
// RecipeRepository is the sqlx-backed implementation of domain.RecipeRepository. Every method in
// this file issues its SQL through db.
type RecipeRepository struct {
db *sqlx.DB
}
// Compile-time check to ensure the RecipeRepository implements domain.RecipeRepository.
// Assigning a typed nil pointer has no runtime cost but fails the build if a method is missing.
var _ domain.RecipeRepository = (*RecipeRepository)(nil)
// NewRecipeRepository creates the recipe repository backed by db. It is the data-access layer for
// every recipe-related operation in this package: recipes, their tags, favorites and the recipe
// of the week. Callers (presumably the recipe service) receive it behind the
// domain.RecipeRepository interface.
func NewRecipeRepository(db *sqlx.DB) domain.RecipeRepository {
	return &RecipeRepository{db: db}
}
// CreateRecipe inserts recipe into the recipes table and, on success, stores the generated primary
// key in recipe.Id. The recipe is passed by pointer so the caller can read the new ID afterwards;
// no other field of recipe is modified.
//
// Column sources:
//   - created is written from recipe.Created exactly as the caller set it (the database does not
//     fill it here), so callers should populate it before calling.
//   - modified is always inserted as NULL for a new recipe.
//   - duration and ingredients (sections + ingredients) are stored as JSON documents.
//   - instructions are stored via pq.Array as the list of each instruction's content.
//
// Encoding, query-building and database errors are returned to the caller.
func (r *RecipeRepository) CreateRecipe(recipe *domain.Recipe) error {
	durationJSON, err := json.Marshal(recipe.Duration)
	if err != nil {
		return fmt.Errorf("Failed to encode recipe duration: %w", err)
	}
	ingredientsJSON, err := json.Marshal(domain.RecipeIngredientStore{
		Sections:    recipe.Sections,
		Ingredients: recipe.Ingredients,
	})
	if err != nil {
		return fmt.Errorf("Failed to encode recipe ingredients: %w", err)
	}

	// Only the text of each instruction is persisted.
	instructions := make([]string, len(recipe.Instructions))
	for i, instruction := range recipe.Instructions {
		instructions[i] = instruction.Content
	}

	psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
	query := psql.
		Insert("recipes").
		Columns(
			"title",
			"description",
			"instructions",
			"serves",
			"difficulty",
			"duration",
			"category",
			"ingredients",
			"userid",
			"modified",
			"created",
		).
		Values(
			recipe.Title,
			recipe.Description,
			pq.Array(instructions),
			recipe.Serves,
			recipe.Difficulty,
			durationJSON,
			string(recipe.Category),
			ingredientsJSON,
			recipe.UserId,
			nil, // a brand-new recipe has never been modified
			recipe.Created,
		).
		Suffix("RETURNING id")

	_sql, args, err := query.ToSql()
	if err != nil {
		return fmt.Errorf("Failed to construct query: %w", err)
	}

	var id int
	if err := r.db.Get(&id, _sql, args...); err != nil {
		return fmt.Errorf("Failed to create recipe: %w", err)
	}
	recipe.Id = id
	return nil
}
// EditRecipe overwrites the editable columns of an existing recipe (title, description,
// instructions, serves, difficulty, duration, category and ingredients) and sets modified to the
// current UTC time. recipe.Id must be positive because it identifies the row to update.
//
// Only a recipe that is owned by userId and has not been soft-deleted is updated. If no such row
// exists, nothing is modified and an error is returned. That covers an unknown ID, a different
// owner, or a deleted recipe. Encoding, query-building and database errors are also returned to
// the caller.
func (r *RecipeRepository) EditRecipe(recipe *domain.Recipe, userId int) error {
	if recipe.Id <= 0 {
		return fmt.Errorf("Recipe must contain an ID. Cannot edit unknown recipe.")
	}

	durationJSON, err := json.Marshal(recipe.Duration)
	if err != nil {
		return fmt.Errorf("Failed to encode recipe duration: %w", err)
	}
	ingredientsJSON, err := json.Marshal(domain.RecipeIngredientStore{
		Sections:    recipe.Sections,
		Ingredients: recipe.Ingredients,
	})
	if err != nil {
		return fmt.Errorf("Failed to encode recipe ingredients: %w", err)
	}

	// Only the text of each instruction is persisted.
	instructions := make([]string, len(recipe.Instructions))
	for i, instruction := range recipe.Instructions {
		instructions[i] = instruction.Content
	}

	// The deleted = false condition keeps soft-deleted recipes read-only. This matches
	// DeleteRecipe, GetRecipe and IsRecipeOwner, which all ignore deleted rows.
	psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
	query := psql.
		Update("recipes").
		Set("title", recipe.Title).
		Set("description", recipe.Description).
		Set("instructions", pq.Array(instructions)).
		Set("serves", recipe.Serves).
		Set("difficulty", recipe.Difficulty).
		Set("duration", durationJSON).
		Set("category", string(recipe.Category)).
		Set("ingredients", ingredientsJSON).
		Set("modified", time.Now().UTC()).
		Where(sq.Eq{
			"id":      recipe.Id,
			"userid":  userId,
			"deleted": false,
		})

	_sql, args, err := query.ToSql()
	if err != nil {
		return fmt.Errorf("Failed to construct query: %w", err)
	}

	result, err := r.db.Exec(_sql, args...)
	if err != nil {
		return fmt.Errorf("Failed to update recipe: %w", err)
	}

	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("Failed to get rows affected: %w", err)
	}
	if rows != 1 {
		return fmt.Errorf("Modified an unexpected number of rows. Expected 1, modified %d.", rows)
	}
	return nil
}
// DeleteRecipe soft-deletes a recipe by setting its deleted flag to true and stamping modified with
// the current UTC time. The row itself is kept.
//
// The update only matches a row whose id is recipeId, whose owner (userid) is userId and which is
// not already deleted. A user therefore cannot delete someone else's recipe. If no row matches,
// nothing is modified and an error is returned. That happens for a recipe that does not exist, is
// owned by someone else, or is already deleted. The error does not say which case occurred, so
// callers that need to distinguish "not the owner" from "not found" must check ownership
// separately (e.g. with IsRecipeOwner). Query-building and database errors are also returned.
func (r *RecipeRepository) DeleteRecipe(recipeId, userId int) error {
	psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
	query := psql.
		Update("recipes").
		Set("deleted", true).
		Set("modified", time.Now().UTC()).
		Where(sq.Eq{
			"id":      recipeId,
			"userid":  userId,
			"deleted": false,
		})

	// Named _sql so the database/sql package is not shadowed.
	_sql, args, err := query.ToSql()
	if err != nil {
		return fmt.Errorf("Failed to build delete query: %w", err)
	}

	result, err := r.db.Exec(_sql, args...)
	if err != nil {
		return fmt.Errorf("Failed to delete recipe: %w", err)
	}

	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("Failed to get rows affected: %w", err)
	}
	if rows != 1 {
		return fmt.Errorf("Incorrect number of rows modified. Expected 1, received %d.", rows)
	}
	return nil
}
// GetRecipe loads a single, non-deleted recipe by id with one SELECT (no transaction is used).
//
// If no matching row exists, including when the recipe is soft-deleted, it returns
// (nil, sql.ErrNoRows) unwrapped, so callers can compare the error directly. Other scan or decode
// failures are wrapped and returned with a nil recipe.
//
// After the row is loaded, the recipe's tags are fetched with GetRecipeTags. When userId is
// non-nil, its favorite status for that user is fetched with GetRecipeFavorite. Both lookups are
// best-effort: a failure is printed and the recipe is still returned, with Tags and/or Favorite
// left as the helper left them.
func (r *RecipeRepository) GetRecipe(id int, userId *int) (*domain.Recipe, error) {
	psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
	query := psql.
		Select(
			"id",
			"title",
			"description",
			"instructions",
			"serves",
			"difficulty",
			"duration",
			"category",
			"ingredients",
			"userid",
			"modified",
			"created",
			"deleted",
		).
		From("recipes").
		Where(sq.Eq{
			"id":      id,
			"deleted": false,
		})

	_sql, args, err := query.ToSql()
	if err != nil {
		return nil, fmt.Errorf("Failed to construct sql query: %w", err)
	}

	var (
		durationBytes   []byte
		ingredientBytes []byte
		instructions    pq.StringArray
		recipe          domain.Recipe
	)
	if err := r.db.QueryRowx(_sql, args...).Scan(
		&recipe.Id,
		&recipe.Title,
		&recipe.Description,
		&instructions,
		&recipe.Serves,
		&recipe.Difficulty,
		&durationBytes,
		&recipe.Category,
		&ingredientBytes,
		&recipe.UserId,
		&recipe.Modified,
		&recipe.Created,
		&recipe.Deleted,
	); err != nil {
		if err == sql.ErrNoRows {
			return nil, err
		}
		return nil, fmt.Errorf("Failed to locate recipe (id: %d) in database: %w", id, err)
	}

	// duration is stored as JSON. An empty/NULL column leaves the zero-value RecipeDuration.
	if len(durationBytes) > 0 {
		if err := json.Unmarshal(durationBytes, &recipe.Duration); err != nil {
			return nil, fmt.Errorf("Failed to parse duration from database: %w", err)
		}
	}

	// ingredients is a JSON RecipeIngredientStore (sections + ingredients).
	// NOTE(review): no legacy/alternate format is attempted here; any decode failure is an error.
	if len(ingredientBytes) > 0 {
		var store domain.RecipeIngredientStore
		if err := json.Unmarshal(ingredientBytes, &store); err != nil {
			return nil, fmt.Errorf("Failed to parse ingredients from database: %w", err)
		}
		recipe.Ingredients = store.Ingredients
		recipe.Sections = store.Sections
	} else {
		recipe.Ingredients = []domain.RecipeIngredient{}
	}

	// Only the instruction text is stored; rebuild the structured instructions from it.
	for _, instruction := range instructions {
		recipe.Instructions = append(recipe.Instructions, domain.RecipeInstruction{Content: instruction})
	}

	// Best-effort enrichment: failures are reported but do not fail the lookup.
	if err := r.GetRecipeTags(&recipe); err != nil {
		fmt.Printf("ERROR getting recipe tags. %s\n", err.Error())
	}
	// Without a user, Favorite keeps its zero value (false).
	if userId != nil {
		if err := r.GetRecipeFavorite(&recipe, *userId); err != nil {
			fmt.Printf("ERROR getting recipe favorite status. %s\n", err.Error())
		}
	}

	return &recipe, nil
}
// GetRecipes loads each recipe in ids by calling GetRecipe once per ID, in the order given. Tags,
// and favorite status when userId is non-nil, are included as GetRecipe provides them. No
// transaction spans the individual lookups.
//
// IDs that do not exist or belong to soft-deleted recipes are skipped silently, not reported as
// errors. The returned slice is nil when nothing is found, including when ids is empty. Any other
// error aborts the call and is returned with a nil slice.
func (r *RecipeRepository) GetRecipes(ids []int, userId *int) ([]domain.Recipe, error) {
	var recipes []domain.Recipe
	for _, id := range ids {
		recipe, err := r.GetRecipe(id, userId)
		switch {
		case err == sql.ErrNoRows:
			// Missing or deleted recipe: skip it.
			continue
		case err != nil:
			return nil, err
		case recipe != nil:
			recipes = append(recipes, *recipe)
		}
	}
	return recipes, nil
}
// isBitActive reports whether bit number pos of bits is set. Bits are counted from 0 at the least
// significant bit. SearchRecipes uses it to read its filter masks, where each bit enables one
// option.
func isBitActive(bits, pos int) bool {
	lowest := (bits >> pos) & 1
	return lowest != 0
}
// SearchRecipes will search the recipe table using the provided filters and return an unbound list
// of recipes. The filters are fairly complex, they are stored as bit masks. A more details
// description can be found in the recipe service implementation. Any errors will be bubbled to the
// caller.
//
// The favorites parameter is used to only return filters favorited by the userId provided.
//
// TODO: Pagination is required, to provide infinite scroll.
//
// 12/28/25: This function has changed, now longer returns the recipes, but their IDs for fetching
// elsewhere.
//
// 2/3/26: Refactored this large function to use Squirrel for simpler generation. Reduced line count by 50,
// but this is still insane. We need to clean this up.
//
// This function will only return recipes that are not deleted. Any recipes marked deleted will be ignored
// and the standard "not-found" error will be returned.
func (r *RecipeRepository) SearchRecipes(filters domain.SearchFilters, userId *int, favorites bool) ([]int, error) {
// Begin creating the query
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
query := psql.Select("r.id").From("recipes r")
// Only select fields where the recipe ID can be found in the favorites table (mapped to user ID)
if favorites && userId != nil {
query = query.
Join("favorites f ON f.recipeId = r.id").
Where(sq.Eq{"f.userid": *userId})
}
// Compute and add meal type filters (7 bit options)
var mealCategories []string
for i := range 7 {
if isBitActive(filters.MealType, i) {
mealCategories = append(mealCategories, string(domain.ParseMeal(i)))
}
}
if len(mealCategories) > 0 {
query = query.Where(sq.Eq{"category": mealCategories})
}
// Compute and add time filters (5 bit options)
var timeOr sq.Or
for i := range 5 {
if isBitActive(filters.Time, i) {
switch i {
case 0:
timeOr = append(timeOr, sq.Lt{"(duration->>'total')::int": 15})
case 1:
timeOr = append(timeOr, sq.Expr("(duration->>'total')::int BETWEEN 15 AND 30"))
case 2:
timeOr = append(timeOr, sq.Expr("(duration->>'total')::int BETWEEN 30 AND 60"))
case 3:
timeOr = append(timeOr, sq.Expr("(duration->>'total')::int BETWEEN 60 AND 120"))
case 4:
timeOr = append(timeOr, sq.Gt{"(duration->>'total')::int": 120})
}
}
}
if len(timeOr) > 0 {
query = query.Where(timeOr)
}
// Compute and add difficulty filters (5 bit options)
var difficulties []int
for i := range 5 {
if isBitActive(filters.Difficulty, i) {
difficulties = append(difficulties, i+1)
}
}
if len(difficulties) > 0 {
query = query.Where(sq.Eq{"difficulty": difficulties})
}
// Compute and add serving size filters (5 bit options)
var servingOr sq.Or
for i := range 5 {
if isBitActive(filters.ServingSize, i) {
switch i {
case 0:
servingOr = append(servingOr, sq.Expr("serves BETWEEN 1 AND 2"))
case 1:
servingOr = append(servingOr, sq.Expr("serves BETWEEN 2 AND 4"))
case 2:
servingOr = append(servingOr, sq.Expr("serves BETWEEN 4 AND 6"))
case 3:
servingOr = append(servingOr, sq.Expr("serves BETWEEN 6 AND 8"))
case 4:
servingOr = append(servingOr, sq.Gt{"serves": 8})
}
}
}
if len(servingOr) > 0 {
query = query.Where(servingOr)
}
// Handle search with full-text search and ILIKE fallback
if filters.Search != "" {
spl := strings.Split(filters.Search, " ")
var cleaned []string
// Sanitize search terms
replacer := strings.NewReplacer(
"'", "",
"-", "",
"&", "",
"|", "",
"!", "",
":", "",
"(", "",
")", "",
)
for _, term := range spl {
q := strings.TrimSpace(replacer.Replace(term))
if q != "" {
cleaned = append(cleaned, q+":*") // Add prefix matching
}
}
if len(cleaned) > 0 {
vectorQuery := strings.Join(cleaned, " | ")
// Build search condition as raw SQL expression
// We'll use sq.Expr for the entire OR clause
var searchConditions []string
var searchArgs []interface{}
// Full-text search
searchConditions = append(searchConditions, "r.search_vector @@ to_tsquery('english', ?)")
searchArgs = append(searchArgs, vectorQuery)
// ILIKE fallback for substring matching
for _, term := range spl {
cleanTerm := strings.TrimSpace(replacer.Replace(term))
if cleanTerm != "" {
searchConditions = append(searchConditions, "r.title ILIKE ?")
searchArgs = append(searchArgs, "%"+cleanTerm+"%")
searchConditions = append(searchConditions, "r.description ILIKE ?")
searchArgs = append(searchArgs, "%"+cleanTerm+"%")
}
}
// Combine all conditions with OR
searchExpr := fmt.Sprintf("(%s)", strings.Join(searchConditions, " OR "))
query = query.Where(sq.Expr(searchExpr, searchArgs...))
// Add ordering for search results
query = query.
OrderBy(fmt.Sprintf("CASE WHEN r.search_vector @@ to_tsquery('english', '%s') THEN 1 ELSE 2 END", vectorQuery)).
OrderBy(fmt.Sprintf("ts_rank(r.search_vector, to_tsquery('english', '%s')) DESC", vectorQuery)).
OrderBy(fmt.Sprintf("ts_rank_cd(r.search_vector, to_tsquery('english', '%s')) DESC", vectorQuery))
}
}
// Exclude deleted recipes
query = query.Where(sq.Eq{"deleted": false})
sql, args, err := query.ToSql()
if err != nil {
return nil, fmt.Errorf("Failed to build query: %w", err)
}
// Execute query using SQLX
var ids []int
if err = r.db.Select(&ids, sql, args...); err != nil {
return nil, fmt.Errorf("Failed to query recipes: %w", err)
}
return ids, nil
}
// CreateRecipeTags associates tags (given by name) with an existing recipe; only recipe.Id is used.
// Names are trimmed and lowercased, and empty names and duplicates are dropped. Tags that do not
// exist yet are created, and a RecipeTags mapping row is inserted for every tag. Everything runs
// in one transaction, so either all tags and mappings are written or none are. Any error is
// returned to the caller.
//
// NOTE(review): a mapping that already exists is inserted again. Whether that errors depends on
// the RecipeTags constraints, which are not visible here.
func (r *RecipeRepository) CreateRecipeTags(recipe domain.Recipe, tags []string) error {
	tx, err := r.db.Beginx()
	if err != nil {
		return err
	}
	// Rolling back after a successful Commit is a harmless no-op.
	defer tx.Rollback()

	if err := attachRecipeTags(tx, recipe.Id, normalizeTagNames(tags)); err != nil {
		return err
	}
	return tx.Commit()
}

// UpdateRecipeTags replaces all tags of a recipe with tags. Inside one transaction it deletes
// every existing RecipeTags mapping for recipe.Id. It then creates any missing tags and maps each
// normalized name to the recipe, using the same normalization as CreateRecipeTags. An empty
// (after normalization) tags list therefore removes all of the recipe's tags. recipe.Id must be
// positive. Any error is returned to the caller and nothing is committed.
func (r *RecipeRepository) UpdateRecipeTags(recipe domain.Recipe, tags []string) error {
	if recipe.Id <= 0 {
		return fmt.Errorf("[ERROR] Recipe must have a valid ID")
	}

	tx, err := r.db.Beginx()
	if err != nil {
		return err
	}
	// Rolling back after a successful Commit is a harmless no-op.
	defer tx.Rollback()

	// Step 1: drop every existing mapping for the recipe.
	_sql, args, err := sq.StatementBuilder.
		PlaceholderFormat(sq.Dollar).
		Delete("RecipeTags").
		Where(sq.Eq{"RecipeId": recipe.Id}).
		ToSql()
	if err != nil {
		return fmt.Errorf("[ERROR] failed to build delete recipe tags query: %w", err)
	}
	if _, err := tx.Exec(_sql, args...); err != nil {
		return fmt.Errorf("[ERROR] failed to delete existing recipe tags: %w", err)
	}

	// Step 2: recreate the mappings. With no tags this is a no-op and only the delete is committed.
	if err := attachRecipeTags(tx, recipe.Id, normalizeTagNames(tags)); err != nil {
		return err
	}
	return tx.Commit()
}

// normalizeTagNames trims and lowercases each tag and drops empty names and duplicates. The
// remaining names keep the order in which they first appear in tags.
func normalizeTagNames(tags []string) []string {
	seen := make(map[string]struct{}, len(tags))
	names := make([]string, 0, len(tags))
	for _, tag := range tags {
		name := strings.ToLower(strings.TrimSpace(tag))
		if name == "" {
			continue
		}
		if _, dup := seen[name]; dup {
			continue
		}
		seen[name] = struct{}{}
		names = append(names, name)
	}
	return names
}

// attachRecipeTags upserts every name into the tags table, creating missing tags and reusing
// existing ones. It then inserts a RecipeTags row linking each resulting tag ID to recipeID. All
// statements run on tx; committing or rolling back is the caller's job. An empty names slice is a
// no-op. names is expected to be de-duplicated (see normalizeTagNames).
func attachRecipeTags(tx *sqlx.Tx, recipeID int, names []string) error {
	if len(names) == 0 {
		return nil
	}
	psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)

	tagIDs := make([]int, 0, len(names))
	for _, name := range names {
		// DO UPDATE (a no-op rename) instead of DO NOTHING so that RETURNING also yields the ID of
		// a tag that already exists.
		_sql, args, err := psql.
			Insert("tags").
			Columns("name").
			Values(name).
			Suffix("ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name RETURNING id").
			ToSql()
		if err != nil {
			return fmt.Errorf("failed to build tag insert query: %w", err)
		}
		var tagID int
		if err := tx.QueryRowx(_sql, args...).Scan(&tagID); err != nil {
			return fmt.Errorf("failed to retrieve or create tag: %w", err)
		}
		tagIDs = append(tagIDs, tagID)
	}

	for _, tagID := range tagIDs {
		_sql, args, err := psql.
			Insert("RecipeTags").
			Columns("RecipeId", "TagId").
			Values(recipeID, tagID).
			ToSql()
		if err != nil {
			return fmt.Errorf("failed to build recipe tag mapping query: %w", err)
		}
		if _, err := tx.Exec(_sql, args...); err != nil {
			return fmt.Errorf("failed to insert recipe tag mapping: %w", err)
		}
	}
	return nil
}
// GetUserRecipes gets a list of a users owned recipes. This function does not ensure the user is
// authenticated or exists. If nothing is found, a blank slice will be returned. The resulting list
// is sorted by the created dates, newest first. Any errors will be bubbled to the caller.
//
// 12/28/25: This now returns just the IDs, the service can handle fetching them.
//
// This function will only return recipes that are not deleted. Any recipes marked deleted will be ignored
// and the standard "not-found" error will be returned.
func (r *RecipeRepository) GetUserRecipesIds(userId int) ([]int, error) {
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
query := psql.
Select("id").
From("recipes").
Where(sq.Eq{
"userid": userId,
"deleted": false,
}).
OrderBy("created DESC")
_sql, args, err := query.ToSql()
if err != nil {
return []int{}, fmt.Errorf("Failed to construct SQL query: %w", err)
}
var ids []int
if err := r.db.Select(&ids, _sql, args...); err != nil {
return []int{}, fmt.Errorf("Failed to get user recipes: %w", err)
}
return ids, nil
}
// GetUserRecipes gets a list of a users favorited recipes. This function does not ensure the user is
// authenticated or exists. If nothing is found, a blank slice will be returned. The resulting list
// is sorted by the created dates, newest first. Any errors will be bubbled to the caller.
//
// 12/28/25: This now just returns the IDs, so the service can handle the fetching.
//
// This function will only return recipes that are not deleted. Any recipes marked deleted will be ignored
// and the standard "not-found" error will be returned.
func (r *RecipeRepository) GetUserFavoriteRecipesIds(userId int) ([]int, error) {
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
query := psql.
Select("r.id").
From("favorites f").
Join("recipes r on r.id = f.recipeid").
Where(sq.Eq{
"f.userid": userId,
"deleted": false,
}).
OrderBy("f.created DESC")
_sql, args, err := query.ToSql()
if err != nil {
return []int{}, fmt.Errorf("Failed to construct SQL query: %w", err)
}
fmt.Println(_sql)
var ids []int
if err := r.db.Select(&ids, _sql, args...); err != nil {
return []int{}, fmt.Errorf("Failed to get users' favorite recipes: %w", err)
}
return ids, nil
}
// GetRecipeTags fills recipe.Tags, in place, with every tag linked to recipe.Id through the
// recipetags mapping table. Only the ID needs to be set on recipe, and a nil recipe is silently
// ignored. Once the query is built, recipe.Tags is reset to an empty slice before the rows are
// read, so a recipe with no tags ends up with [] rather than nil. Errors are returned to the caller.
func (r *RecipeRepository) GetRecipeTags(recipe *domain.Recipe) error {
	if recipe == nil {
		return nil
	}

	_sql, args, err := sq.StatementBuilder.
		PlaceholderFormat(sq.Dollar).
		Select("t.*").
		From("tags t").
		Join("recipetags rt on rt.tagid = t.id").
		Where(sq.Eq{"rt.recipeid": recipe.Id}).
		ToSql()
	if err != nil {
		return fmt.Errorf("Failed to construct sql query: %w", err)
	}

	recipe.Tags = []domain.Tag{}
	if err = r.db.Select(&recipe.Tags, _sql, args...); err != nil {
		return fmt.Errorf("Failed to get recipe tags: %w", err)
	}
	return nil
}
// GetRecipeFavorite requires a recipe to be filled with at least an ID. This function will use the
// ID defined in the provided recipe to fill the favorite status of the recipe, based on the provided
// userId. The recipe is modified in place and is not returned. If the recipe is nil, the function
// will return nothing (skipping). Any errors will be bubbled to the caller.
func (r *RecipeRepository) GetRecipeFavorite(recipe *domain.Recipe, userId int) error {
if recipe == nil {
return nil
}
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
query := psql.
Select("COUNT(*)").
From("favorites").
Where(sq.Eq{
"recipeid": recipe.Id,
"userid": userId,
})
_sql, args, err := query.ToSql()
if err != nil {
return fmt.Errorf("Failed to construct SQL query: %w", err)
}
var count int
if err := r.db.Get(&count, _sql, args...); err != nil {
return fmt.Errorf("Failed to get recipe favorite status: %w", err)
}
recipe.Favorite = count > 0
return nil
}
// GetRecipeOfTheWeekId returns the ID of the most recently added recipe of the week, ordered by
// recipeoftheweek.created. Entries whose recipe has been soft-deleted are skipped. Only the ID is
// returned so the caller can fetch the recipe itself.
//
// It returns (nil, nil) when there is no eligible entry. Query-building and database errors are
// returned to the caller. userId is currently unused; it is accepted only for the repository
// interface's signature.
func (r *RecipeRepository) GetRecipeOfTheWeekId(userId *int) (*int, error) {
	psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
	query := psql.
		Select("r.id").
		From("recipes r").
		Join("recipeoftheweek rw ON rw.recipeid = r.id").
		Where(sq.Eq{"r.deleted": false}).
		OrderBy("rw.created DESC").
		Limit(1)

	_sql, args, err := query.ToSql()
	if err != nil {
		return nil, fmt.Errorf("Failed to build SQL query: %w", err)
	}

	var recipeId int
	if err := r.db.Get(&recipeId, _sql, args...); err != nil {
		if err == sql.ErrNoRows {
			// No recipe of the week yet: not an error.
			return nil, nil
		}
		return nil, fmt.Errorf("Failed to locate recipe in database: %w", err)
	}
	return &recipeId, nil
}
// IsRecipeOwner reports whether userId owns the recipe identified by recipeId. A recipe that does
// not exist or has been soft-deleted is reported as not owned (false, nil). A query-building error
// is returned as-is; database errors are wrapped and returned to the caller.
func (r *RecipeRepository) IsRecipeOwner(userId, recipeId int) (bool, error) {
	_sql, args, err := sq.StatementBuilder.
		PlaceholderFormat(sq.Dollar).
		Select("userid").
		From("recipes").
		Where(sq.Eq{
			"id":      recipeId,
			"deleted": false,
		}).
		ToSql()
	if err != nil {
		return false, err
	}

	var ownerId int
	err = r.db.Get(&ownerId, _sql, args...)
	switch {
	case err == sql.ErrNoRows:
		return false, nil
	case err != nil:
		return false, fmt.Errorf("Failed to get recipe owner id: %w", err)
	}
	return ownerId == userId, nil
}