move var connector to top of main()
[pstop.git] / main.go
diff --git a/main.go b/main.go
index 4825729..711cfff 100644 (file)
--- a/main.go
+++ b/main.go
@@ -1,71 +1,35 @@
-// Top like progream which collects information from MySQL's
+// pstop - Top like progream which collects information from MySQL's
 // performance_schema database.
 package main
 
 import (
-       "database/sql"
-       "errors"
        "flag"
        "fmt"
        "log"
        "os"
-       "os/signal"
-       "regexp"
        "runtime/pprof"
-       "syscall"
-       "time"
 
        _ "github.com/go-sql-driver/mysql"
-       "github.com/nsf/termbox-go"
 
-       "github.com/sjmudd/mysql_defaults_file"
+       "github.com/sjmudd/pstop/app"
+       "github.com/sjmudd/pstop/connector"
        "github.com/sjmudd/pstop/lib"
-       "github.com/sjmudd/pstop/state"
        "github.com/sjmudd/pstop/version"
 )
 
-const (
-       sql_driver = "mysql"
-       db         = "performance_schema"
-)
-
 var (
-       flag_version = flag.Bool("version", false, "Show the version of "+lib.MyName())
-       flag_debug   = flag.Bool("debug", false, "Enabling debug logging")
-       flag_help    = flag.Bool("help", false, "Provide some help for "+lib.MyName())
-       cpuprofile   = flag.String("cpuprofile", "", "write cpu profile to file")
-
-       re_valid_version = regexp.MustCompile(`^(5\.[67]\.|10\.[01])`)
+       cpuprofile         = flag.String("cpuprofile", "", "write cpu profile to file")
+       flag_debug         = flag.Bool("debug", false, "Enabling debug logging")
+       flag_defaults_file = flag.String("defaults-file", "", "Provide a defaults-file to use to connect to MySQL")
+       flag_help          = flag.Bool("help", false, "Provide some help for "+lib.MyName())
+       flag_host          = flag.String("host", "", "Provide the hostname of the MySQL to connect to")
+       flag_password      = flag.String("password", "", "Provide the password when connecting to the MySQL server")
+       flag_port          = flag.Int("port", 0, "Provide the port number of the MySQL to connect to (default: 3306)") /* deliberately 0 here, defaults to 3306 elsewhere */
+       flag_socket        = flag.String("socket", "", "Provide the path to the local MySQL server to connect to")
+       flag_user          = flag.String("user", "", "Provide the username to connect with to MySQL (default: $USER)")
+       flag_version       = flag.Bool("version", false, "Show the version of "+lib.MyName())
 )
 
-func get_db_handle() *sql.DB {
-       var err error
-       var dbh *sql.DB
-       lib.Logger.Println("get_db_handle() connecting to database")
-
-       dbh, err = mysql_defaults_file.OpenUsingDefaultsFile(sql_driver, "", "performance_schema")
-       if err != nil {
-               log.Fatal(err)
-       }
-       if err = dbh.Ping(); err != nil {
-               log.Fatal(err)
-       }
-
-       return dbh
-}
-
-// make chan for termbox events and run a poller to send events to the channel
-// - return the channel
-func new_tb_chan() chan termbox.Event {
-       termboxChan := make(chan termbox.Event)
-       go func() {
-               for {
-                       termboxChan <- termbox.PollEvent()
-               }
-       }()
-       return termboxChan
-}
-
 func usage() {
        fmt.Println(lib.MyName() + " - " + lib.Copyright())
        fmt.Println("")
@@ -75,29 +39,20 @@ func usage() {
        fmt.Println("Usage: " + lib.MyName() + " <options>")
        fmt.Println("")
        fmt.Println("Options:")
-       fmt.Println("-help      show this help message")
-       fmt.Println("-version   show the version")
-}
-
-// pstop requires MySQL 5.6+ or MariaDB 10.0+. Check the version
-// rather than giving an error message if the requires P_S tables can't
-// be found.
-func validate_mysql_version(dbh *sql.DB) error {
-       lib.Logger.Println("validate_mysql_version()")
-
-       _, mysql_version := lib.SelectGlobalVariableByVariableName(dbh, "VERSION")
-
-       lib.Logger.Println("- mysql_version: '" + mysql_version + "'")
-       if !re_valid_version.MatchString(mysql_version) {
-               err := errors.New(lib.MyName() + " does not work with MySQL version " + mysql_version)
-               return err
-       }
-       lib.Logger.Println("- MySQL version is valid, continuing")
-
-       return nil
+       fmt.Println("--defaults-file=/path/to/defaults.file   Connect to MySQL using given defaults-file")
+       fmt.Println("--help                                   Show this help message")
+       fmt.Println("--version                                Show the version")
+       fmt.Println("--host=<hostname>                        MySQL host to connect to")
+       fmt.Println("--password=<password>                    Password to use when connecting")
+       fmt.Println("--port=<port>                            MySQL port to connect to")
+       fmt.Println("--socket=<path>                          MySQL path of the socket to connect to")
+       fmt.Println("--user=<user>                            User to connect with")
 }
 
 func main() {
+       var connector connector.Connector
+       var defaults_file string = ""
+
        flag.Parse()
 
        // clean me up
@@ -124,78 +79,48 @@ func main() {
 
        lib.Logger.Println("Starting " + lib.MyName())
 
-       dbh := get_db_handle()
-       if err := validate_mysql_version(dbh); err != nil {
-               log.Fatal(err)
-       }
-
-       var state state.State
-       interval := time.Second
-       sigChan := make(chan os.Signal, 1)
-       done := make(chan struct{})
-       defer close(done)
-       termboxChan := new_tb_chan()
-
-       signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
-
-       ticker := time.NewTicker(interval) // generate a periodic signal
-
-       state.Setup(dbh)
-
-       finished := false
-       for !finished {
-               select {
-               case <-done:
-                       fmt.Println("exiting")
-                       finished = true
-               case sig := <-sigChan:
-                       fmt.Println("Caught a signal", sig)
-                       done <- struct{}{}
-               case <-ticker.C:
-                       state.Collect()
-                       state.Display()
-               case event := <-termboxChan:
-                       // switch on event type
-                       switch event.Type {
-                       case termbox.EventKey: // actions depend on key
-                               switch event.Key {
-                               case termbox.KeyCtrlZ, termbox.KeyCtrlC, termbox.KeyEsc:
-                                       finished = true
-                               case termbox.KeyTab: // tab - change display modes
-                                       state.DisplayNext()
-                                       state.Display()
-                               }
-                               switch event.Ch {
-                               case '-': // decrease the interval if > 1
-                                       if interval > time.Second {
-                                               ticker.Stop()
-                                               interval -= time.Second
-                                               ticker = time.NewTicker(interval)
-                                       }
-                               case '+': // increase interval by creating a new ticker
-                                       ticker.Stop()
-                                       interval += time.Second
-                                       ticker = time.NewTicker(interval)
-                               case 'h': // help
-                                       state.SetHelp(!state.Help())
-                               case 'q': // quit
-                                       finished = true
-                               case 't': // toggle between absolute/relative statistics
-                                       state.SetWantRelativeStats(!state.WantRelativeStats())
-                                       state.Display()
-                               case 'z': // reset the statistics to now by taking a query of current values
-                                       state.ResetDBStatistics()
-                                       state.Display()
-                               }
-                       case termbox.EventResize: // set sizes
-                               state.ScreenSetSize(event.Width, event.Height)
-                               state.Display()
-                       case termbox.EventError: // quit
-                               log.Fatalf("Quitting because of termbox error: \n%s\n", event.Err)
+       if *flag_host != "" || *flag_socket != "" {
+               lib.Logger.Println("--host= or --socket= defined")
+               var components = make(map[string]string)
+               if *flag_host != "" && *flag_socket != "" {
+                       fmt.Println(lib.MyName() + ": Do not specify --host and --socket together")
+                       os.Exit(1)
+               }
+               if *flag_host != "" {
+                       components["host"] = *flag_host
+               }
+               if *flag_port != 0 {
+                       if *flag_socket == "" {
+                               components["port"] = fmt.Sprintf("%d", *flag_port)
+                       } else {
+                               fmt.Println(lib.MyName() + ": Do not specify --socket and --port together")
+                               os.Exit(1)
                        }
                }
+               if *flag_socket != "" {
+                       components["socket"] = *flag_socket
+               }
+               if *flag_user != "" {
+                       components["user"] = *flag_user
+               }
+               if *flag_password != "" {
+                       components["password"] = *flag_password
+               }
+               connector.ConnectByComponents(components)
+       } else {
+               if flag_defaults_file != nil && *flag_defaults_file != "" {
+                       lib.Logger.Println("--defaults-file defined")
+                       defaults_file = *flag_defaults_file
+               } else {
+                       lib.Logger.Println("connecting by implicit defaults file")
+               }
+               connector.ConnectByDefaultsFile(defaults_file)
        }
-       state.Cleanup()
-       ticker.Stop()
+
+       var app app.App
+
+       app.Setup(connector.Handle())
+       app.Run()
+       app.Cleanup()
        lib.Logger.Println("Terminating " + lib.MyName())
 }