add mutexs to prevent data races, re-organize a bit

This commit is contained in:
stryan 2021-01-06 19:12:56 -05:00
parent db0af7edcd
commit f4cc9498cb
4 changed files with 67 additions and 37 deletions

View File

@ -12,12 +12,12 @@ import (
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
"layeh.com/gumble/gumble" "layeh.com/gumble/gumble"
"layeh.com/gumble/gumbleutil"
) )
type BridgeState struct { type BridgeState struct {
ActiveConn chan bool ActiveConn chan bool
Connected bool Connected bool
Mode BridgeMode
Client *gumble.Client Client *gumble.Client
DiscordUsers map[string]bool DiscordUsers map[string]bool
MumbleUsers map[string]bool MumbleUsers map[string]bool
@ -46,10 +46,7 @@ func startBridge(discord *discordgo.Session, discordGID string, discordCID strin
if l.BridgeConf.MumbleInsecure { if l.BridgeConf.MumbleInsecure {
tlsConfig.InsecureSkipVerify = true tlsConfig.InsecureSkipVerify = true
} }
l.BridgeConf.Config.Attach(gumbleutil.Listener{
Connect: l.mumbleConnect,
UserChange: l.mumbleUserChange,
})
mumble, err := gumble.DialWithDialer(new(net.Dialer), l.BridgeConf.MumbleAddr, l.BridgeConf.Config, &tlsConfig) mumble, err := gumble.DialWithDialer(new(net.Dialer), l.BridgeConf.MumbleAddr, l.BridgeConf.Config, &tlsConfig)
if err != nil { if err != nil {
@ -100,27 +97,9 @@ func startBridge(discord *discordgo.Session, discordGID string, discordCID strin
} }
} }
}() }()
l.ConnectedLock.Lock()
//Setup initial discord state
g, err := discord.State.Guild(discordGID)
if err != nil {
log.Println("error finding guild")
panic(err)
}
for _, vs := range g.VoiceStates {
if vs.ChannelID == discordCID {
l.Bridge.DiscordUserCount = l.Bridge.DiscordUserCount + 1
u, err := discord.User(vs.UserID)
if err != nil {
log.Println("error looking up username")
l.Bridge.DiscordUsers[u.Username] = true
l.Bridge.Client.Do(func() {
l.Bridge.Client.Self.Channel.Send(fmt.Sprintf("%v has joined Discord channel\n", u.Username), false)
})
}
}
}
l.Bridge.Connected = true l.Bridge.Connected = true
l.ConnectedLock.Unlock()
select { select {
case sig := <-c: case sig := <-c:
@ -153,6 +132,8 @@ func discordStatusUpdate(dg *discordgo.Session, host, port string, l *Listener)
log.Printf("error pinging mumble server %v\n", err) log.Printf("error pinging mumble server %v\n", err)
dg.UpdateListeningStatus("an error pinging mumble") dg.UpdateListeningStatus("an error pinging mumble")
} else { } else {
l.UserCountLock.Lock()
l.ConnectedLock.Lock()
curr = resp.ConnectedUsers curr = resp.ConnectedUsers
if l.Bridge.Connected { if l.Bridge.Connected {
curr = curr - 1 curr = curr - 1
@ -169,6 +150,8 @@ func discordStatusUpdate(dg *discordgo.Session, host, port string, l *Listener)
status = fmt.Sprintf("%v users in Mumble\n", curr) status = fmt.Sprintf("%v users in Mumble\n", curr)
} }
} }
l.ConnectedLock.Unlock()
l.UserCountLock.Unlock()
dg.UpdateListeningStatus(status) dg.UpdateListeningStatus(status)
} }
} }
@ -184,6 +167,7 @@ func AutoBridge(s *discordgo.Session, l *Listener) {
return return
} }
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
l.UserCountLock.Lock()
if !l.Bridge.Connected && l.Bridge.MumbleUserCount > 0 && l.Bridge.DiscordUserCount > 0 { if !l.Bridge.Connected && l.Bridge.MumbleUserCount > 0 && l.Bridge.DiscordUserCount > 0 {
log.Println("users detected in mumble and discord, bridging") log.Println("users detected in mumble and discord, bridging")
die := make(chan bool) die := make(chan bool)
@ -196,5 +180,6 @@ func AutoBridge(s *discordgo.Session, l *Listener) {
MumbleReset() MumbleReset()
DiscordReset() DiscordReset()
} }
l.UserCountLock.Unlock()
} }
} }

View File

@ -23,7 +23,6 @@ type BridgeConfig struct {
MumbleAddr string MumbleAddr string
MumbleInsecure bool MumbleInsecure bool
MumbleChannel string MumbleChannel string
Mode BridgeMode
Command string Command string
GID string GID string
CID string CID string

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"log" "log"
"strings" "strings"
"sync"
"time" "time"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
@ -11,17 +12,43 @@ import (
) )
type Listener struct { type Listener struct {
BridgeConf *BridgeConfig BridgeConf *BridgeConfig
Bridge *BridgeState Bridge *BridgeState
UserCountLock *sync.Mutex
ConnectedLock *sync.Mutex
} }
func (l *Listener) ready(s *discordgo.Session, event *discordgo.Ready) { func (l *Listener) ready(s *discordgo.Session, event *discordgo.Ready) {
log.Println("READY event registered") log.Println("READY event registered")
//Setup initial discord state
var g *discordgo.Guild
for _, i := range event.Guilds {
if i.ID == l.BridgeConf.GID {
g = i
}
}
for _, vs := range g.VoiceStates {
if vs.ChannelID == l.BridgeConf.CID {
l.UserCountLock.Lock()
l.Bridge.DiscordUserCount = l.Bridge.DiscordUserCount + 1
u, err := s.User(vs.UserID)
if err != nil {
log.Println("error looking up username")
}
l.Bridge.DiscordUsers[u.Username] = true
if l.Bridge.Connected {
l.Bridge.Client.Do(func() {
l.Bridge.Client.Self.Channel.Send(fmt.Sprintf("%v has joined Discord channel\n", u.Username), false)
})
}
l.UserCountLock.Unlock()
}
}
} }
func (l *Listener) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) { func (l *Listener) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
if l.BridgeConf.Mode == BridgeModeConstant { if l.Bridge.Mode == BridgeModeConstant {
return return
} }
@ -119,13 +146,13 @@ func (l *Listener) messageCreate(s *discordgo.Session, m *discordgo.MessageCreat
} }
if strings.HasPrefix(m.Content, prefix+" auto") { if strings.HasPrefix(m.Content, prefix+" auto") {
if l.BridgeConf.Mode != BridgeModeAuto { if l.Bridge.Mode != BridgeModeAuto {
l.BridgeConf.Mode = BridgeModeAuto l.Bridge.Mode = BridgeModeAuto
l.Bridge.AutoChan = make(chan bool) l.Bridge.AutoChan = make(chan bool)
go AutoBridge(s, l) go AutoBridge(s, l)
} else { } else {
l.Bridge.AutoChan <- true l.Bridge.AutoChan <- true
l.BridgeConf.Mode = BridgeModeManual l.Bridge.Mode = BridgeModeManual
} }
} }
} }
@ -145,6 +172,7 @@ func (l *Listener) guildCreate(s *discordgo.Session, event *discordgo.GuildCreat
} }
func (l *Listener) voiceUpdate(s *discordgo.Session, event *discordgo.VoiceStateUpdate) { func (l *Listener) voiceUpdate(s *discordgo.Session, event *discordgo.VoiceStateUpdate) {
l.UserCountLock.Lock()
if event.GuildID == l.BridgeConf.GID { if event.GuildID == l.BridgeConf.GID {
if event.ChannelID == l.BridgeConf.CID { if event.ChannelID == l.BridgeConf.CID {
//get user //get user
@ -155,17 +183,21 @@ func (l *Listener) voiceUpdate(s *discordgo.Session, event *discordgo.VoiceState
//check to see if actually new user //check to see if actually new user
if l.Bridge.DiscordUsers[u.Username] { if l.Bridge.DiscordUsers[u.Username] {
//not actually new user //not actually new user
l.UserCountLock.Unlock()
return return
} }
log.Println("user joined watched discord channel") log.Println("user joined watched discord channel")
l.ConnectedLock.Lock()
if l.Bridge.Connected { if l.Bridge.Connected {
l.Bridge.Client.Do(func() { l.Bridge.Client.Do(func() {
l.Bridge.Client.Self.Channel.Send(fmt.Sprintf("%v has joined Discord channel\n", u.Username), false) l.Bridge.Client.Self.Channel.Send(fmt.Sprintf("%v has joined Discord channel\n", u.Username), false)
}) })
} }
l.ConnectedLock.Unlock()
l.Bridge.DiscordUsers[u.Username] = true l.Bridge.DiscordUsers[u.Username] = true
log.Println(l.Bridge.DiscordUsers) log.Println(l.Bridge.DiscordUsers)
l.Bridge.DiscordUserCount = l.Bridge.DiscordUserCount + 1 l.Bridge.DiscordUserCount = l.Bridge.DiscordUserCount + 1
l.UserCountLock.Unlock()
} }
if event.ChannelID == "" { if event.ChannelID == "" {
//leave event, trigger recount of active users //leave event, trigger recount of active users
@ -173,6 +205,7 @@ func (l *Listener) voiceUpdate(s *discordgo.Session, event *discordgo.VoiceState
g, err := s.State.Guild(event.GuildID) g, err := s.State.Guild(event.GuildID)
if err != nil { if err != nil {
// Could not find guild. // Could not find guild.
l.UserCountLock.Unlock()
return return
} }
@ -190,12 +223,15 @@ func (l *Listener) voiceUpdate(s *discordgo.Session, event *discordgo.VoiceState
} }
delete(l.Bridge.DiscordUsers, u.Username) delete(l.Bridge.DiscordUsers, u.Username)
log.Println("user left watched discord channel") log.Println("user left watched discord channel")
l.ConnectedLock.Lock()
if l.Bridge.Connected { if l.Bridge.Connected {
l.Bridge.Client.Do(func() { l.Bridge.Client.Do(func() {
l.Bridge.Client.Self.Channel.Send(fmt.Sprintf("%v has left Discord channel\n", u.Username), false) l.Bridge.Client.Self.Channel.Send(fmt.Sprintf("%v has left Discord channel\n", u.Username), false)
}) })
} }
l.ConnectedLock.Unlock()
l.Bridge.DiscordUserCount = count l.Bridge.DiscordUserCount = count
l.UserCountLock.Unlock()
} }
} }
@ -214,6 +250,7 @@ func (l *Listener) mumbleConnect(e *gumble.ConnectEvent) {
} }
func (l *Listener) mumbleUserChange(e *gumble.UserChangeEvent) { func (l *Listener) mumbleUserChange(e *gumble.UserChangeEvent) {
l.UserCountLock.Lock()
if e.Type.Has(gumble.UserChangeConnected) || e.Type.Has(gumble.UserChangeChannel) || e.Type.Has(gumble.UserChangeDisconnected) { if e.Type.Has(gumble.UserChangeConnected) || e.Type.Has(gumble.UserChangeChannel) || e.Type.Has(gumble.UserChangeDisconnected) {
l.Bridge.MumbleUsers = make(map[string]bool) l.Bridge.MumbleUsers = make(map[string]bool)
for _, user := range l.Bridge.Client.Self.Channel.Users { for _, user := range l.Bridge.Client.Self.Channel.Users {
@ -225,4 +262,5 @@ func (l *Listener) mumbleUserChange(e *gumble.UserChangeEvent) {
} }
} }
} }
l.UserCountLock.Unlock()
} }

18
main.go
View File

@ -6,12 +6,14 @@ import (
"os" "os"
"os/signal" "os/signal"
"strconv" "strconv"
"sync"
"syscall" "syscall"
"time" "time"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
"github.com/joho/godotenv" "github.com/joho/godotenv"
"layeh.com/gumble/gumble" "layeh.com/gumble/gumble"
"layeh.com/gumble/gumbleutil"
_ "layeh.com/gumble/opus" _ "layeh.com/gumble/opus"
) )
@ -75,7 +77,6 @@ func main() {
MumbleAddr: *mumbleAddr + ":" + strconv.Itoa(*mumblePort), MumbleAddr: *mumbleAddr + ":" + strconv.Itoa(*mumblePort),
MumbleInsecure: *mumbleInsecure, MumbleInsecure: *mumbleInsecure,
MumbleChannel: *mumbleChannel, MumbleChannel: *mumbleChannel,
Mode: -1,
Command: *discordCommand, Command: *discordCommand,
GID: *discordGID, GID: *discordGID,
CID: *discordCID, CID: *discordCID,
@ -86,8 +87,11 @@ func main() {
MumbleUserCount: 0, MumbleUserCount: 0,
DiscordUserCount: 0, DiscordUserCount: 0,
DiscordUsers: make(map[string]bool), DiscordUsers: make(map[string]bool),
MumbleUsers: make(map[string]bool),
} }
l := &Listener{BridgeConf, Bridge} ul := &sync.Mutex{}
cl := &sync.Mutex{}
l := &Listener{BridgeConf, Bridge, ul, cl}
// Discord setup // Discord setup
// Open Websocket // Open Websocket
@ -100,6 +104,10 @@ func main() {
discord.AddHandler(l.guildCreate) discord.AddHandler(l.guildCreate)
discord.AddHandler(l.voiceUpdate) discord.AddHandler(l.voiceUpdate)
err = discord.Open() err = discord.Open()
l.BridgeConf.Config.Attach(gumbleutil.Listener{
Connect: l.mumbleConnect,
UserChange: l.mumbleUserChange,
})
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return return
@ -113,14 +121,14 @@ func main() {
case "auto": case "auto":
log.Println("bridge starting in automatic mode") log.Println("bridge starting in automatic mode")
Bridge.AutoChan = make(chan bool) Bridge.AutoChan = make(chan bool)
BridgeConf.Mode = BridgeModeAuto Bridge.Mode = BridgeModeAuto
go AutoBridge(discord, l) go AutoBridge(discord, l)
case "manual": case "manual":
log.Println("bridge starting in manual mode") log.Println("bridge starting in manual mode")
BridgeConf.Mode = BridgeModeManual Bridge.Mode = BridgeModeManual
case "constant": case "constant":
log.Println("bridge starting in constant mode") log.Println("bridge starting in constant mode")
BridgeConf.Mode = BridgeModeConstant Bridge.Mode = BridgeModeConstant
go startBridge(discord, *discordGID, *discordCID, l, make(chan bool)) go startBridge(discord, *discordGID, *discordCID, l, make(chan bool))
default: default:
discord.Close() discord.Close()