Provide more useful error message if MySQL version is not expected
[pstop.git] / main.go
// Top-like program which collects information from MySQL's
// performance_schema database.
3 package main
4
5 import (
6         "database/sql"
7         "errors"
8         "flag"
9         "fmt"
10         "log"
11         "os"
12         "os/signal"
13         "regexp"
14         "runtime/pprof"
15         "syscall"
16         "time"
17
18         _ "github.com/go-sql-driver/mysql"
19         "github.com/nsf/termbox-go"
20
21         "github.com/sjmudd/mysql_defaults_file"
22         "github.com/sjmudd/pstop/lib"
23         "github.com/sjmudd/pstop/state"
24         "github.com/sjmudd/pstop/version"
25 )
26
// sql_driver is the database/sql driver name; the blank import of
// github.com/go-sql-driver/mysql makes this driver available.
const sql_driver = "mysql"

// db is the name of the MySQL schema this program works with.
const db = "performance_schema"
31
32 var (
33         flag_version = flag.Bool("version", false, "Show the version of "+lib.MyName())
34         flag_debug   = flag.Bool("debug", false, "Enabling debug logging")
35         flag_help    = flag.Bool("help", false, "Provide some help for "+lib.MyName())
36         cpuprofile   = flag.String("cpuprofile", "", "write cpu profile to file")
37
38         re_valid_version = regexp.MustCompile(`^(5\.[67]\.|10\.[01])`)
39 )
40
41 func get_db_handle() *sql.DB {
42         var err error
43         var dbh *sql.DB
44         lib.Logger.Println("get_db_handle() connecting to database")
45
46         dbh, err = mysql_defaults_file.OpenUsingDefaultsFile(sql_driver, "", "performance_schema")
47         if err != nil {
48                 log.Fatal(err)
49         }
50         if err = dbh.Ping(); err != nil {
51                 log.Fatal(err)
52         }
53
54         return dbh
55 }
56
57 // make chan for termbox events and run a poller to send events to the channel
58 // - return the channel
59 func new_tb_chan() chan termbox.Event {
60         termboxChan := make(chan termbox.Event)
61         go func() {
62                 for {
63                         termboxChan <- termbox.PollEvent()
64                 }
65         }()
66         return termboxChan
67 }
68
69 func usage() {
70         fmt.Println(lib.MyName() + " - " + lib.Copyright())
71         fmt.Println("")
72         fmt.Println("Top-like program to show MySQL activity by using information collected")
73         fmt.Println("from performance_schema.")
74         fmt.Println("")
75         fmt.Println("Usage: " + lib.MyName() + " <options>")
76         fmt.Println("")
77         fmt.Println("Options:")
78         fmt.Println("-help      show this help message")
79         fmt.Println("-version   show the version")
80 }
81
82 // pstop requires MySQL 5.6+ or MariaDB 10.0+. Check the version
83 // rather than giving an error message if the requires P_S tables can't
84 // be found.
85 func validate_mysql_version(dbh *sql.DB) error {
86         lib.Logger.Println("validate_mysql_version()")
87
88         _, mysql_version := lib.SelectGlobalVariableByVariableName(dbh, "VERSION")
89
90         lib.Logger.Println("- mysql_version: '" + mysql_version + "'")
91         if !re_valid_version.MatchString(mysql_version) {
92                 err := errors.New(lib.MyName() + " does not work with MySQL version " + mysql_version)
93                 return err
94         }
95         lib.Logger.Println("- MySQL version is valid, continuing")
96
97         return nil
98 }
99
// main parses the command line, connects to MySQL, validates the server
// version, and then runs the interactive display loop. The loop ends when the
// user presses q, Esc, Ctrl-C or Ctrl-Z, or when the process receives SIGINT
// or SIGTERM.
//
// Keys handled inside the loop:
//
//	Tab  switch to the next display mode
//	-    decrease the refresh interval by one second (minimum one second)
//	+    increase the refresh interval by one second
//	h    toggle help
//	t    toggle between absolute and relative statistics
//	z    reset the statistics to the current values
func main() {
	flag.Parse()

	// Optional CPU profiling for the lifetime of main.
	if *cpuprofile != "" {
		f, err := os.Create(*cpuprofile)
		if err != nil {
			log.Fatal(err)
		}
		// Deferred calls run last-in first-out, so profiling is stopped
		// (flushing its data) before the file is closed.
		defer f.Close()
		if err = pprof.StartCPUProfile(f); err != nil {
			log.Fatal(err)
		}
		defer pprof.StopCPUProfile()
	}

	if *flag_debug {
		lib.Logger.EnableLogging(true)
	}
	if *flag_version {
		fmt.Println(lib.MyName() + " version " + version.Version())
		return
	}
	if *flag_help {
		usage()
		return
	}

	lib.Logger.Println("Starting " + lib.MyName())

	dbh := get_db_handle()
	if err := validate_mysql_version(dbh); err != nil {
		log.Fatal(err)
	}

	var state state.State
	interval := time.Second
	sigChan := make(chan os.Signal, 1)
	termboxChan := new_tb_chan()

	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	ticker := time.NewTicker(interval) // generate a periodic signal

	state.Setup(dbh)

	finished := false
	for !finished {
		select {
		case sig := <-sigChan:
			// End the loop directly. The previous version sent on an
			// unbuffered "done" channel that only this same loop could
			// receive from, so the send blocked forever.
			fmt.Println("Caught a signal", sig)
			finished = true
		case <-ticker.C:
			state.Collect()
			state.Display()
		case event := <-termboxChan:
			// switch on event type
			switch event.Type {
			case termbox.EventKey: // actions depend on key
				switch event.Key {
				case termbox.KeyCtrlZ, termbox.KeyCtrlC, termbox.KeyEsc:
					finished = true
				case termbox.KeyTab: // tab - change display modes
					state.DisplayNext()
					state.Display()
				}
				switch event.Ch {
				case '-': // decrease the interval if > 1
					if interval > time.Second {
						ticker.Stop()
						interval -= time.Second
						ticker = time.NewTicker(interval)
					}
				case '+': // increase interval by creating a new ticker
					ticker.Stop()
					interval += time.Second
					ticker = time.NewTicker(interval)
				case 'h': // help
					state.SetHelp(!state.Help())
				case 'q': // quit
					finished = true
				case 't': // toggle between absolute/relative statistics
					state.SetWantRelativeStats(!state.WantRelativeStats())
					state.Display()
				case 'z': // reset the statistics to now by taking a query of current values
					state.ResetDBStatistics()
					state.Display()
				}
			case termbox.EventResize: // set sizes
				state.ScreenSetSize(event.Width, event.Height)
				state.Display()
			case termbox.EventError: // quit
				// NOTE: log.Fatalf calls os.Exit, so deferred calls (e.g. the CPU
				// profile stop) and state.Cleanup() below do not run on this path.
				log.Fatalf("Quitting because of termbox error: \n%s\n", event.Err)
			}
		}
	}
	state.Cleanup()
	ticker.Stop()
	lib.Logger.Println("Terminating " + lib.MyName())
}