diff --git a/server/dao/tag.go b/server/dao/tag.go index 398e6e8..5227c74 100644 --- a/server/dao/tag.go +++ b/server/dao/tag.go @@ -19,7 +19,7 @@ func CreateTag(tag string) (model.Tag, error) { if err := db.Create(&t).Error; err != nil { return model.Tag{}, err } - return t, nil + return GetTagByID(t.ID) } func CreateTagWithType(tag string, tagType string) (model.Tag, error) { @@ -82,13 +82,34 @@ func GetTagByName(name string) (model.Tag, error) { } func SetTagInfo(id uint, description string, aliasOf *uint, tagType string) error { + // Get the tag information old, err := GetTagByID(id) if err != nil { return err } - if aliasOf != nil && len(old.Aliases) > 0 { - return model.NewRequestError("Tag already has aliases, cannot set alias_of") + + // If the alias tag is an alias itself, we need to find its root tag + if aliasOf != nil { + tag, err := GetTagByID(*aliasOf) + if err != nil { + return err + } + if tag.AliasOf != nil { + aliasOf = tag.AliasOf + } } + + // If the tag has aliases, we need to update their alias_of field + if aliasOf != nil && len(old.Aliases) > 0 { + for _, alias := range old.Aliases { + err := db.Model(&alias).Update("alias_of", *aliasOf).Error + if err != nil { + return err + } + } + } + + // Update the tag information t := model.Tag{Model: gorm.Model{ ID: id, }, Description: description, Type: tagType, AliasOf: aliasOf} @@ -113,24 +134,51 @@ func ListTags() ([]model.Tag, error) { // SetTagAlias sets a tag with the given ID having the given alias. func SetTagAlias(tagID uint, alias string) error { - // Set a tag as an alias of another tag - var t model.Tag - if err := db.Where("name = ?", alias).First(&t).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - // create - newTag, err := CreateTag(alias) - if err != nil { - return err - } - t = newTag - } else { + exists, err := ExistsTagByID(tagID) + if err != nil { + return err + } + if !exists { + return model.NewNotFoundError("Tag not found") + } + + exists, err = ExistsTag(alias) + if err != nil { + return err + } + if !exists { + // Create the alias tag if it does not exist + _, err := CreateTag(alias) + if err != nil { return err } } - if t.ID == tagID { - return model.NewRequestError("Tag cannot be an alias of itself") + // Get the alias tag + tag, err := GetTagByName(alias) + if err != nil { + return err } - return db.Model(&t).Update("alias_of", tagID).Error + // If the alias tag is an alias itself, we need to find its root tag + if tag.AliasOf != nil { + tag, err = GetTagByID(*tag.AliasOf) + if err != nil { + return err + } + } + // If the tag has aliases, we need to update their alias_of field + for _, alias := range tag.Aliases { + err := db.Model(&alias).Update("alias_of", tagID).Error + if err != nil { + return err + } + } + tag.Aliases = nil + // A tag cannot be an alias of itself + if tag.ID == tagID { + return model.NewRequestError("A tag cannot be an alias of itself") + } + // Set the alias_of field of the tag + return db.Model(&tag).Update("alias_of", tagID).Error } // RemoveTagAliasOf sets a tag is an independent tag, removing its alias relationship. @@ -180,3 +228,11 @@ func ExistsTag(name string) (bool, error) { } return count > 0, nil } + +func ExistsTagByID(id uint) (bool, error) { + var count int64 + if err := db.Model(&model.Tag{}).Where("id = ?", id).Count(&count).Error; err != nil { + return false, err + } + return count > 0, nil +} diff --git a/server/dao/tag_test.go b/server/dao/tag_test.go new file mode 100644 index 0000000..cc8084a --- /dev/null +++ b/server/dao/tag_test.go @@ -0,0 +1,111 @@ +package dao + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTag(t *testing.T) { + // Create tags + tag1, err := CreateTag("test1") + assert.Nil(t, err) + tag2, err := CreateTag("test2") + assert.Nil(t, err) + tag3, err := CreateTagWithType("test3", "type1") + assert.Nil(t, err) + + // Get tag by ID + fetchedTag, err := GetTagByID(tag1.ID) + assert.Nil(t, err) + assert.Equal(t, tag1.Name, fetchedTag.Name) + + // Get tag by Name + fetchedTag, err = GetTagByName(tag2.Name) + assert.Nil(t, err) + assert.Equal(t, tag2.ID, fetchedTag.ID) + + // Search tags + tags, err := SearchTag("test", true) + assert.Nil(t, err) + assert.GreaterOrEqual(t, len(tags), 3) + + // Update tag + err = SetTagInfo(tag1.ID, "updated description", nil, "updated type") + assert.Nil(t, err) + updatedTag, err := GetTagByID(tag1.ID) + assert.Nil(t, err) + assert.Equal(t, "updated description", updatedTag.Description) + assert.Equal(t, "updated type", updatedTag.Type) + + // Set tag alias + err = SetTagAlias(tag1.ID, tag2.Name) + assert.Nil(t, err) + err = SetTagAlias(tag1.ID, tag3.Name) + assert.Nil(t, err) + err = SetTagAlias(tag1.ID, "test4") + assert.Nil(t, err) + tag4, err := GetTagByName("test4") + assert.Nil(t, err) + tag1, err = GetTagByID(tag1.ID) + assert.Nil(t, err) + aliasesIDs := []uint{} + for _, alias := range tag1.Aliases { + aliasesIDs = append(aliasesIDs, alias.ID) + } + assert.Equal(t, []uint{tag2.ID, tag3.ID, tag4.ID}, aliasesIDs) + + // let a tag which has alias point to another tag + tag5, err := CreateTag("test5") + assert.Nil(t, err) + err = SetTagAlias(tag5.ID, tag1.Name) + assert.Nil(t, err) + tag1, err = GetTagByID(tag1.ID) + assert.Nil(t, err) + tag2, err = GetTagByID(tag2.ID) + assert.Nil(t, err) + tag3, err = GetTagByID(tag3.ID) + assert.Nil(t, err) + tag4, err = GetTagByID(tag4.ID) + assert.Nil(t, err) + tag5, err = GetTagByID(tag5.ID) + assert.Nil(t, err) + assert.Empty(t, tag1.Aliases) + assert.Equal(t, &tag5.ID, tag1.AliasOf) + assert.Equal(t, &tag5.ID, tag2.AliasOf) + assert.Equal(t, &tag5.ID, tag3.AliasOf) + assert.Equal(t, &tag5.ID, tag4.AliasOf) + assert.Nil(t, tag5.AliasOf) + + // Same operation as above, but using `SetTagInfo` + tag6, err := CreateTag("test6") + assert.Nil(t, err) + err = SetTagInfo(tag5.ID, "", &tag6.ID, "") + assert.Nil(t, err) + tag1, err = GetTagByID(tag1.ID) + assert.Nil(t, err) + tag2, err = GetTagByID(tag2.ID) + assert.Nil(t, err) + tag3, err = GetTagByID(tag3.ID) + assert.Nil(t, err) + tag4, err = GetTagByID(tag4.ID) + assert.Nil(t, err) + tag5, err = GetTagByID(tag5.ID) + assert.Nil(t, err) + tag6, err = GetTagByID(tag6.ID) + assert.Nil(t, err) + assert.Equal(t, &tag6.ID, tag1.AliasOf) + assert.Equal(t, &tag6.ID, tag2.AliasOf) + assert.Equal(t, &tag6.ID, tag3.AliasOf) + assert.Equal(t, &tag6.ID, tag4.AliasOf) + assert.Equal(t, &tag6.ID, tag5.AliasOf) + assert.Empty(t, tag5.Aliases) + assert.Nil(t, tag6.AliasOf) + + // cleanup + d, err := db.DB() + assert.Nil(t, err) + _ = d.Close() + _ = os.Remove("test.db") +}