From 6daaf3ef59d41794cc8e149da4da021fd7a63495 Mon Sep 17 00:00:00 2001 From: Doug Lauder Date: Mon, 20 Sep 2021 17:32:47 -0400 Subject: [PATCH] ensure @mentions only sent to team members (#1216) --- mattermost-plugin/server/notifications.go | 8 +++ .../notify/plugindelivery/plugin_delivery.go | 27 ++++++- server/services/notify/plugindelivery/user.go | 42 ++++++----- .../notify/plugindelivery/user_test.go | 70 +++++++++++++++---- 4 files changed, 112 insertions(+), 35 deletions(-) diff --git a/mattermost-plugin/server/notifications.go b/mattermost-plugin/server/notifications.go index abe890dc5..6134982f4 100644 --- a/mattermost-plugin/server/notifications.go +++ b/mattermost-plugin/server/notifications.go @@ -59,3 +59,11 @@ func (da *pluginAPIAdapter) GetUserByID(userID string) (*model.User, error) { func (da *pluginAPIAdapter) GetUserByUsername(name string) (*model.User, error) { return da.client.User.GetByUsername(name) } + +func (da *pluginAPIAdapter) GetTeamMember(teamID string, userID string) (*model.TeamMember, error) { + return da.client.Team.GetMember(teamID, userID) +} + +func (da *pluginAPIAdapter) GetChannelByID(channelID string) (*model.Channel, error) { + return da.client.Channel.Get(channelID) +} diff --git a/server/services/notify/plugindelivery/plugin_delivery.go b/server/services/notify/plugindelivery/plugin_delivery.go index 8281d50e5..75e4fb3bd 100644 --- a/server/services/notify/plugindelivery/plugin_delivery.go +++ b/server/services/notify/plugindelivery/plugin_delivery.go @@ -19,11 +19,17 @@ type PluginAPI interface { // CreatePost creates a post. CreatePost(post *model.Post) error - // GetUserByIS gets a user by their ID. + // GetUserByID gets a user by their ID. GetUserByID(userID string) (*model.User, error) // GetUserByUsername gets a user by their username. GetUserByUsername(name string) (*model.User, error) + + // GetTeamMember gets a team member by their user id. + GetTeamMember(teamID string, userID string) (*model.TeamMember, error) + + // GetChannelByID gets a Channel by its ID. + GetChannelByID(channelID string) (*model.Channel, error) } // PluginDelivery provides ability to send notifications to direct message channels via Mattermost plugin API. @@ -42,7 +48,13 @@ func New(botID string, serverRoot string, api PluginAPI) *PluginDelivery { } func (pd *PluginDelivery) Deliver(mentionUsername string, extract string, evt notify.BlockChangeEvent) error { - user, err := userFromUsername(pd.api, mentionUsername) + // determine which team the workspace is associated with + teamID, err := pd.getTeamID(evt) + if err != nil { + return fmt.Errorf("cannot determine teamID for block change notification: %w", err) + } + + member, err := teamMemberFromUsername(pd.api, mentionUsername, teamID) if err != nil { if isErrNotFound(err) { // not really an error; could just be someone typed "@sometext" @@ -57,7 +69,7 @@ func (pd *PluginDelivery) Deliver(mentionUsername string, extract string, evt no return fmt.Errorf("cannot find user: %w", err) } - channel, err := pd.api.GetDirectChannel(user.Id, pd.botID) + channel, err := pd.api.GetDirectChannel(member.UserId, pd.botID) if err != nil { return fmt.Errorf("cannot get direct channel: %w", err) } @@ -70,3 +82,12 @@ func (pd *PluginDelivery) Deliver(mentionUsername string, extract string, evt no } return pd.api.CreatePost(post) } + +func (pd *PluginDelivery) getTeamID(evt notify.BlockChangeEvent) (string, error) { + // for now, the workspace ID is also the channel ID + channel, err := pd.api.GetChannelByID(evt.Workspace) + if err != nil { + return "", err + } + return channel.TeamId, nil +} diff --git a/server/services/notify/plugindelivery/user.go b/server/services/notify/plugindelivery/user.go index 2ad7baa1f..f9efec030 100644 --- a/server/services/notify/plugindelivery/user.go +++ b/server/services/notify/plugindelivery/user.go @@ -13,30 +13,36 @@ const ( usernameSpecialChars = ".-_ " ) -func userFromUsername(api PluginAPI, username string) (*mm_model.User, error) { - user, err := api.GetUserByUsername(username) - if err == nil { - return user, nil +func teamMemberFromUsername(api PluginAPI, username string, teamID string) (*mm_model.TeamMember, error) { + // check for usernames that might have trailing punctuation + var user *mm_model.User + var err error + ok := true + trimmed := username + for ok { + user, err = api.GetUserByUsername(trimmed) + if err != nil && !isErrNotFound(err) { + return nil, err + } + + if err == nil { + break + } + + trimmed, ok = trimUsernameSpecialChar(trimmed) } - // only continue if the error is `ErrNotFound` - if !isErrNotFound(err) { + if user == nil { return nil, err } - // check for usernames in substrings without trailing punctuation - trimmed, ok := trimUsernameSpecialChar(username) - for ; ok; trimmed, ok = trimUsernameSpecialChar(trimmed) { - userFromTrimmed, err2 := api.GetUserByUsername(trimmed) - if err2 != nil && !isErrNotFound(err2) { - return nil, err2 - } - - if err2 == nil { - return userFromTrimmed, nil - } + // make sure user is member of team. + member, err := api.GetTeamMember(teamID, user.Id) + if err != nil { + return nil, err } - return nil, err + + return member, nil } // trimUsernameSpecialChar tries to remove the last character from word if it diff --git a/server/services/notify/plugindelivery/user_test.go b/server/services/notify/plugindelivery/user_test.go index 4347a40a7..1577cb857 100644 --- a/server/services/notify/plugindelivery/user_test.go +++ b/server/services/notify/plugindelivery/user_test.go @@ -11,6 +11,8 @@ import ( ) var ( + defTeamID = mm_model.NewId() + user1 = &mm_model.User{ Id: mm_model.NewId(), Username: "dlauder", @@ -27,42 +29,56 @@ var ( Id: mm_model.NewId(), Username: "missing_", } + user5 = &mm_model.User{ + Id: mm_model.NewId(), + Username: "wrong_team", + } mockUsers = map[string]*mm_model.User{ "dlauder": user1, "steve.mqueen": user2, "bart_": user3, + "wrong_team": user5, } ) -func Test_userFromUsername(t *testing.T) { +func userToMember(user *mm_model.User, teamID string) *mm_model.TeamMember { + return &mm_model.TeamMember{ + TeamId: teamID, + UserId: user.Id, + } +} + +func Test_teamMemberFromUsername(t *testing.T) { delivery := newPlugAPIMock(mockUsers) tests := []struct { name string uname string - want *mm_model.User + teamID string + want *mm_model.TeamMember wantErr bool }{ - {name: "user1", uname: user1.Username, want: user1, wantErr: false}, - {name: "user1 with period", uname: user1.Username + ".", want: user1, wantErr: false}, - {name: "user1 with period plus more", uname: user1.Username + ". ", want: user1, wantErr: false}, - {name: "user2 with periods", uname: user2.Username + "...", want: user2, wantErr: false}, - {name: "user2 with underscore", uname: user2.Username + "_", want: user2, wantErr: false}, - {name: "user2 with hyphen plus more", uname: user2.Username + "- ", want: user2, wantErr: false}, - {name: "user2 with hyphen plus all", uname: user2.Username + ".-_ ", want: user2, wantErr: false}, - {name: "user3 with underscore", uname: user3.Username + "_", want: user3, wantErr: false}, - {name: "user4 missing", uname: user4.Username, want: nil, wantErr: true}, + {name: "user1", uname: user1.Username, teamID: defTeamID, want: userToMember(user1, defTeamID), wantErr: false}, + {name: "user1 with period", uname: user1.Username + ".", teamID: defTeamID, want: userToMember(user1, defTeamID), wantErr: false}, + {name: "user1 with period plus more", uname: user1.Username + ". ", teamID: defTeamID, want: userToMember(user1, defTeamID), wantErr: false}, + {name: "user2 with periods", uname: user2.Username + "...", teamID: defTeamID, want: userToMember(user2, defTeamID), wantErr: false}, + {name: "user2 with underscore", uname: user2.Username + "_", teamID: defTeamID, want: userToMember(user2, defTeamID), wantErr: false}, + {name: "user2 with hyphen plus more", uname: user2.Username + "- ", teamID: defTeamID, want: userToMember(user2, defTeamID), wantErr: false}, + {name: "user2 with hyphen plus all", uname: user2.Username + ".-_ ", teamID: defTeamID, want: userToMember(user2, defTeamID), wantErr: false}, + {name: "user3 with underscore", uname: user3.Username + "_", teamID: defTeamID, want: userToMember(user3, defTeamID), wantErr: false}, + {name: "user4 missing", uname: user4.Username, want: nil, teamID: defTeamID, wantErr: true}, + {name: "user5 wrong team", uname: user5.Username, teamID: "bogus_team", want: nil, wantErr: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := userFromUsername(delivery, tt.uname) + got, err := teamMemberFromUsername(delivery, tt.uname, tt.teamID) if (err != nil) != tt.wantErr { t.Errorf("userFromUsername() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("userFromUsername() = %v, want %v", got, tt.want) + t.Errorf("userFromUsername()\ngot:\n%v\nwant:\n%v\n", got, tt.want) } }) } @@ -95,7 +111,33 @@ func (m pluginAPIMock) CreatePost(post *mm_model.Post) error { } func (m pluginAPIMock) GetUserByID(userID string) (*mm_model.User, error) { - return nil, nil + for _, user := range m.users { + if user.Id == userID { + return user, nil + } + } + return nil, ErrNotFound{} +} + +func (m pluginAPIMock) GetTeamMember(teamID string, userID string) (*mm_model.TeamMember, error) { + user, err := m.GetUserByID(userID) + if err != nil { + return nil, err + } + + if teamID != defTeamID { + return nil, ErrNotFound{} + } + + member := &mm_model.TeamMember{ + UserId: user.Id, + TeamId: teamID, + } + return member, nil +} + +func (m pluginAPIMock) GetChannelByID(channelID string) (*mm_model.Channel, error) { + return nil, ErrNotFound{} } type ErrNotFound struct{}