ensure @mentions only sent to team members (#1216)

This commit is contained in:
Doug Lauder 2021-09-20 17:32:47 -04:00 committed by GitHub
parent 1c88d1c986
commit 6daaf3ef59
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 112 additions and 35 deletions

View file

@ -59,3 +59,11 @@ func (da *pluginAPIAdapter) GetUserByID(userID string) (*model.User, error) {
func (da *pluginAPIAdapter) GetUserByUsername(name string) (*model.User, error) {
return da.client.User.GetByUsername(name)
}
func (da *pluginAPIAdapter) GetTeamMember(teamID string, userID string) (*model.TeamMember, error) {
return da.client.Team.GetMember(teamID, userID)
}
func (da *pluginAPIAdapter) GetChannelByID(channelID string) (*model.Channel, error) {
return da.client.Channel.Get(channelID)
}

View file

@ -19,11 +19,17 @@ type PluginAPI interface {
// CreatePost creates a post.
CreatePost(post *model.Post) error
// GetUserByIS gets a user by their ID.
// GetUserByID gets a user by their ID.
GetUserByID(userID string) (*model.User, error)
// GetUserByUsername gets a user by their username.
GetUserByUsername(name string) (*model.User, error)
// GetTeamMember gets a team member by their user id.
GetTeamMember(teamID string, userID string) (*model.TeamMember, error)
// GetChannelByID gets a Channel by its ID.
GetChannelByID(channelID string) (*model.Channel, error)
}
// PluginDelivery provides ability to send notifications to direct message channels via Mattermost plugin API.
@ -42,7 +48,13 @@ func New(botID string, serverRoot string, api PluginAPI) *PluginDelivery {
}
func (pd *PluginDelivery) Deliver(mentionUsername string, extract string, evt notify.BlockChangeEvent) error {
user, err := userFromUsername(pd.api, mentionUsername)
// determine which team the workspace is associated with
teamID, err := pd.getTeamID(evt)
if err != nil {
return fmt.Errorf("cannot determine teamID for block change notification: %w", err)
}
member, err := teamMemberFromUsername(pd.api, mentionUsername, teamID)
if err != nil {
if isErrNotFound(err) {
// not really an error; could just be someone typed "@sometext"
@ -57,7 +69,7 @@ func (pd *PluginDelivery) Deliver(mentionUsername string, extract string, evt no
return fmt.Errorf("cannot find user: %w", err)
}
channel, err := pd.api.GetDirectChannel(user.Id, pd.botID)
channel, err := pd.api.GetDirectChannel(member.UserId, pd.botID)
if err != nil {
return fmt.Errorf("cannot get direct channel: %w", err)
}
@ -70,3 +82,12 @@ func (pd *PluginDelivery) Deliver(mentionUsername string, extract string, evt no
}
return pd.api.CreatePost(post)
}
func (pd *PluginDelivery) getTeamID(evt notify.BlockChangeEvent) (string, error) {
// for now, the workspace ID is also the channel ID
channel, err := pd.api.GetChannelByID(evt.Workspace)
if err != nil {
return "", err
}
return channel.TeamId, nil
}

View file

@ -13,30 +13,36 @@ const (
usernameSpecialChars = ".-_ "
)
func userFromUsername(api PluginAPI, username string) (*mm_model.User, error) {
user, err := api.GetUserByUsername(username)
if err == nil {
return user, nil
func teamMemberFromUsername(api PluginAPI, username string, teamID string) (*mm_model.TeamMember, error) {
// check for usernames that might have trailing punctuation
var user *mm_model.User
var err error
ok := true
trimmed := username
for ok {
user, err = api.GetUserByUsername(trimmed)
if err != nil && !isErrNotFound(err) {
return nil, err
}
if err == nil {
break
}
trimmed, ok = trimUsernameSpecialChar(trimmed)
}
// only continue if the error is `ErrNotFound`
if !isErrNotFound(err) {
if user == nil {
return nil, err
}
// check for usernames in substrings without trailing punctuation
trimmed, ok := trimUsernameSpecialChar(username)
for ; ok; trimmed, ok = trimUsernameSpecialChar(trimmed) {
userFromTrimmed, err2 := api.GetUserByUsername(trimmed)
if err2 != nil && !isErrNotFound(err2) {
return nil, err2
}
if err2 == nil {
return userFromTrimmed, nil
}
// make sure user is member of team.
member, err := api.GetTeamMember(teamID, user.Id)
if err != nil {
return nil, err
}
return nil, err
return member, nil
}
// trimUsernameSpecialChar tries to remove the last character from word if it

View file

@ -11,6 +11,8 @@ import (
)
var (
defTeamID = mm_model.NewId()
user1 = &mm_model.User{
Id: mm_model.NewId(),
Username: "dlauder",
@ -27,42 +29,56 @@ var (
Id: mm_model.NewId(),
Username: "missing_",
}
user5 = &mm_model.User{
Id: mm_model.NewId(),
Username: "wrong_team",
}
mockUsers = map[string]*mm_model.User{
"dlauder": user1,
"steve.mqueen": user2,
"bart_": user3,
"wrong_team": user5,
}
)
func Test_userFromUsername(t *testing.T) {
func userToMember(user *mm_model.User, teamID string) *mm_model.TeamMember {
return &mm_model.TeamMember{
TeamId: teamID,
UserId: user.Id,
}
}
func Test_teamMemberFromUsername(t *testing.T) {
delivery := newPlugAPIMock(mockUsers)
tests := []struct {
name string
uname string
want *mm_model.User
teamID string
want *mm_model.TeamMember
wantErr bool
}{
{name: "user1", uname: user1.Username, want: user1, wantErr: false},
{name: "user1 with period", uname: user1.Username + ".", want: user1, wantErr: false},
{name: "user1 with period plus more", uname: user1.Username + ". ", want: user1, wantErr: false},
{name: "user2 with periods", uname: user2.Username + "...", want: user2, wantErr: false},
{name: "user2 with underscore", uname: user2.Username + "_", want: user2, wantErr: false},
{name: "user2 with hyphen plus more", uname: user2.Username + "- ", want: user2, wantErr: false},
{name: "user2 with hyphen plus all", uname: user2.Username + ".-_ ", want: user2, wantErr: false},
{name: "user3 with underscore", uname: user3.Username + "_", want: user3, wantErr: false},
{name: "user4 missing", uname: user4.Username, want: nil, wantErr: true},
{name: "user1", uname: user1.Username, teamID: defTeamID, want: userToMember(user1, defTeamID), wantErr: false},
{name: "user1 with period", uname: user1.Username + ".", teamID: defTeamID, want: userToMember(user1, defTeamID), wantErr: false},
{name: "user1 with period plus more", uname: user1.Username + ". ", teamID: defTeamID, want: userToMember(user1, defTeamID), wantErr: false},
{name: "user2 with periods", uname: user2.Username + "...", teamID: defTeamID, want: userToMember(user2, defTeamID), wantErr: false},
{name: "user2 with underscore", uname: user2.Username + "_", teamID: defTeamID, want: userToMember(user2, defTeamID), wantErr: false},
{name: "user2 with hyphen plus more", uname: user2.Username + "- ", teamID: defTeamID, want: userToMember(user2, defTeamID), wantErr: false},
{name: "user2 with hyphen plus all", uname: user2.Username + ".-_ ", teamID: defTeamID, want: userToMember(user2, defTeamID), wantErr: false},
{name: "user3 with underscore", uname: user3.Username + "_", teamID: defTeamID, want: userToMember(user3, defTeamID), wantErr: false},
{name: "user4 missing", uname: user4.Username, want: nil, teamID: defTeamID, wantErr: true},
{name: "user5 wrong team", uname: user5.Username, teamID: "bogus_team", want: nil, wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := userFromUsername(delivery, tt.uname)
got, err := teamMemberFromUsername(delivery, tt.uname, tt.teamID)
if (err != nil) != tt.wantErr {
t.Errorf("userFromUsername() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("userFromUsername() = %v, want %v", got, tt.want)
t.Errorf("userFromUsername()\ngot:\n%v\nwant:\n%v\n", got, tt.want)
}
})
}
@ -95,7 +111,33 @@ func (m pluginAPIMock) CreatePost(post *mm_model.Post) error {
}
func (m pluginAPIMock) GetUserByID(userID string) (*mm_model.User, error) {
return nil, nil
for _, user := range m.users {
if user.Id == userID {
return user, nil
}
}
return nil, ErrNotFound{}
}
func (m pluginAPIMock) GetTeamMember(teamID string, userID string) (*mm_model.TeamMember, error) {
user, err := m.GetUserByID(userID)
if err != nil {
return nil, err
}
if teamID != defTeamID {
return nil, ErrNotFound{}
}
member := &mm_model.TeamMember{
UserId: user.Id,
TeamId: teamID,
}
return member, nil
}
func (m pluginAPIMock) GetChannelByID(channelID string) (*mm_model.Channel, error) {
return nil, ErrNotFound{}
}
type ErrNotFound struct{}