ensure @mentions only sent to team members (#1216)
This commit is contained in:
parent
1c88d1c986
commit
6daaf3ef59
4 changed files with 112 additions and 35 deletions
|
@ -59,3 +59,11 @@ func (da *pluginAPIAdapter) GetUserByID(userID string) (*model.User, error) {
|
|||
func (da *pluginAPIAdapter) GetUserByUsername(name string) (*model.User, error) {
|
||||
return da.client.User.GetByUsername(name)
|
||||
}
|
||||
|
||||
func (da *pluginAPIAdapter) GetTeamMember(teamID string, userID string) (*model.TeamMember, error) {
|
||||
return da.client.Team.GetMember(teamID, userID)
|
||||
}
|
||||
|
||||
func (da *pluginAPIAdapter) GetChannelByID(channelID string) (*model.Channel, error) {
|
||||
return da.client.Channel.Get(channelID)
|
||||
}
|
||||
|
|
|
@ -19,11 +19,17 @@ type PluginAPI interface {
|
|||
// CreatePost creates a post.
|
||||
CreatePost(post *model.Post) error
|
||||
|
||||
// GetUserByIS gets a user by their ID.
|
||||
// GetUserByID gets a user by their ID.
|
||||
GetUserByID(userID string) (*model.User, error)
|
||||
|
||||
// GetUserByUsername gets a user by their username.
|
||||
GetUserByUsername(name string) (*model.User, error)
|
||||
|
||||
// GetTeamMember gets a team member by their user id.
|
||||
GetTeamMember(teamID string, userID string) (*model.TeamMember, error)
|
||||
|
||||
// GetChannelByID gets a Channel by its ID.
|
||||
GetChannelByID(channelID string) (*model.Channel, error)
|
||||
}
|
||||
|
||||
// PluginDelivery provides ability to send notifications to direct message channels via Mattermost plugin API.
|
||||
|
@ -42,7 +48,13 @@ func New(botID string, serverRoot string, api PluginAPI) *PluginDelivery {
|
|||
}
|
||||
|
||||
func (pd *PluginDelivery) Deliver(mentionUsername string, extract string, evt notify.BlockChangeEvent) error {
|
||||
user, err := userFromUsername(pd.api, mentionUsername)
|
||||
// determine which team the workspace is associated with
|
||||
teamID, err := pd.getTeamID(evt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot determine teamID for block change notification: %w", err)
|
||||
}
|
||||
|
||||
member, err := teamMemberFromUsername(pd.api, mentionUsername, teamID)
|
||||
if err != nil {
|
||||
if isErrNotFound(err) {
|
||||
// not really an error; could just be someone typed "@sometext"
|
||||
|
@ -57,7 +69,7 @@ func (pd *PluginDelivery) Deliver(mentionUsername string, extract string, evt no
|
|||
return fmt.Errorf("cannot find user: %w", err)
|
||||
}
|
||||
|
||||
channel, err := pd.api.GetDirectChannel(user.Id, pd.botID)
|
||||
channel, err := pd.api.GetDirectChannel(member.UserId, pd.botID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot get direct channel: %w", err)
|
||||
}
|
||||
|
@ -70,3 +82,12 @@ func (pd *PluginDelivery) Deliver(mentionUsername string, extract string, evt no
|
|||
}
|
||||
return pd.api.CreatePost(post)
|
||||
}
|
||||
|
||||
func (pd *PluginDelivery) getTeamID(evt notify.BlockChangeEvent) (string, error) {
|
||||
// for now, the workspace ID is also the channel ID
|
||||
channel, err := pd.api.GetChannelByID(evt.Workspace)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return channel.TeamId, nil
|
||||
}
|
||||
|
|
|
@ -13,30 +13,36 @@ const (
|
|||
usernameSpecialChars = ".-_ "
|
||||
)
|
||||
|
||||
func userFromUsername(api PluginAPI, username string) (*mm_model.User, error) {
|
||||
user, err := api.GetUserByUsername(username)
|
||||
if err == nil {
|
||||
return user, nil
|
||||
func teamMemberFromUsername(api PluginAPI, username string, teamID string) (*mm_model.TeamMember, error) {
|
||||
// check for usernames that might have trailing punctuation
|
||||
var user *mm_model.User
|
||||
var err error
|
||||
ok := true
|
||||
trimmed := username
|
||||
for ok {
|
||||
user, err = api.GetUserByUsername(trimmed)
|
||||
if err != nil && !isErrNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
trimmed, ok = trimUsernameSpecialChar(trimmed)
|
||||
}
|
||||
|
||||
// only continue if the error is `ErrNotFound`
|
||||
if !isErrNotFound(err) {
|
||||
if user == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check for usernames in substrings without trailing punctuation
|
||||
trimmed, ok := trimUsernameSpecialChar(username)
|
||||
for ; ok; trimmed, ok = trimUsernameSpecialChar(trimmed) {
|
||||
userFromTrimmed, err2 := api.GetUserByUsername(trimmed)
|
||||
if err2 != nil && !isErrNotFound(err2) {
|
||||
return nil, err2
|
||||
}
|
||||
|
||||
if err2 == nil {
|
||||
return userFromTrimmed, nil
|
||||
}
|
||||
// make sure user is member of team.
|
||||
member, err := api.GetTeamMember(teamID, user.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
|
||||
return member, nil
|
||||
}
|
||||
|
||||
// trimUsernameSpecialChar tries to remove the last character from word if it
|
||||
|
|
|
@ -11,6 +11,8 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
defTeamID = mm_model.NewId()
|
||||
|
||||
user1 = &mm_model.User{
|
||||
Id: mm_model.NewId(),
|
||||
Username: "dlauder",
|
||||
|
@ -27,42 +29,56 @@ var (
|
|||
Id: mm_model.NewId(),
|
||||
Username: "missing_",
|
||||
}
|
||||
user5 = &mm_model.User{
|
||||
Id: mm_model.NewId(),
|
||||
Username: "wrong_team",
|
||||
}
|
||||
|
||||
mockUsers = map[string]*mm_model.User{
|
||||
"dlauder": user1,
|
||||
"steve.mqueen": user2,
|
||||
"bart_": user3,
|
||||
"wrong_team": user5,
|
||||
}
|
||||
)
|
||||
|
||||
func Test_userFromUsername(t *testing.T) {
|
||||
func userToMember(user *mm_model.User, teamID string) *mm_model.TeamMember {
|
||||
return &mm_model.TeamMember{
|
||||
TeamId: teamID,
|
||||
UserId: user.Id,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_teamMemberFromUsername(t *testing.T) {
|
||||
delivery := newPlugAPIMock(mockUsers)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uname string
|
||||
want *mm_model.User
|
||||
teamID string
|
||||
want *mm_model.TeamMember
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "user1", uname: user1.Username, want: user1, wantErr: false},
|
||||
{name: "user1 with period", uname: user1.Username + ".", want: user1, wantErr: false},
|
||||
{name: "user1 with period plus more", uname: user1.Username + ". ", want: user1, wantErr: false},
|
||||
{name: "user2 with periods", uname: user2.Username + "...", want: user2, wantErr: false},
|
||||
{name: "user2 with underscore", uname: user2.Username + "_", want: user2, wantErr: false},
|
||||
{name: "user2 with hyphen plus more", uname: user2.Username + "- ", want: user2, wantErr: false},
|
||||
{name: "user2 with hyphen plus all", uname: user2.Username + ".-_ ", want: user2, wantErr: false},
|
||||
{name: "user3 with underscore", uname: user3.Username + "_", want: user3, wantErr: false},
|
||||
{name: "user4 missing", uname: user4.Username, want: nil, wantErr: true},
|
||||
{name: "user1", uname: user1.Username, teamID: defTeamID, want: userToMember(user1, defTeamID), wantErr: false},
|
||||
{name: "user1 with period", uname: user1.Username + ".", teamID: defTeamID, want: userToMember(user1, defTeamID), wantErr: false},
|
||||
{name: "user1 with period plus more", uname: user1.Username + ". ", teamID: defTeamID, want: userToMember(user1, defTeamID), wantErr: false},
|
||||
{name: "user2 with periods", uname: user2.Username + "...", teamID: defTeamID, want: userToMember(user2, defTeamID), wantErr: false},
|
||||
{name: "user2 with underscore", uname: user2.Username + "_", teamID: defTeamID, want: userToMember(user2, defTeamID), wantErr: false},
|
||||
{name: "user2 with hyphen plus more", uname: user2.Username + "- ", teamID: defTeamID, want: userToMember(user2, defTeamID), wantErr: false},
|
||||
{name: "user2 with hyphen plus all", uname: user2.Username + ".-_ ", teamID: defTeamID, want: userToMember(user2, defTeamID), wantErr: false},
|
||||
{name: "user3 with underscore", uname: user3.Username + "_", teamID: defTeamID, want: userToMember(user3, defTeamID), wantErr: false},
|
||||
{name: "user4 missing", uname: user4.Username, want: nil, teamID: defTeamID, wantErr: true},
|
||||
{name: "user5 wrong team", uname: user5.Username, teamID: "bogus_team", want: nil, wantErr: true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := userFromUsername(delivery, tt.uname)
|
||||
got, err := teamMemberFromUsername(delivery, tt.uname, tt.teamID)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("userFromUsername() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("userFromUsername() = %v, want %v", got, tt.want)
|
||||
t.Errorf("userFromUsername()\ngot:\n%v\nwant:\n%v\n", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -95,7 +111,33 @@ func (m pluginAPIMock) CreatePost(post *mm_model.Post) error {
|
|||
}
|
||||
|
||||
func (m pluginAPIMock) GetUserByID(userID string) (*mm_model.User, error) {
|
||||
return nil, nil
|
||||
for _, user := range m.users {
|
||||
if user.Id == userID {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrNotFound{}
|
||||
}
|
||||
|
||||
func (m pluginAPIMock) GetTeamMember(teamID string, userID string) (*mm_model.TeamMember, error) {
|
||||
user, err := m.GetUserByID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if teamID != defTeamID {
|
||||
return nil, ErrNotFound{}
|
||||
}
|
||||
|
||||
member := &mm_model.TeamMember{
|
||||
UserId: user.Id,
|
||||
TeamId: teamID,
|
||||
}
|
||||
return member, nil
|
||||
}
|
||||
|
||||
func (m pluginAPIMock) GetChannelByID(channelID string) (*mm_model.Channel, error) {
|
||||
return nil, ErrNotFound{}
|
||||
}
|
||||
|
||||
type ErrNotFound struct{}
|
||||
|
|
Loading…
Reference in a new issue