package database

import (
	"crypto/rand"
	"crypto/subtle"
	"errors"
	"regexp"
	"strings"
	"time"

	log "github.com/cihub/seelog"
	"golang.org/x/crypto/scrypt"
)

// alphaNumeric reports whether a string is non-empty and consists only of
// ASCII letters, digits, '_', '-' and '.'.
var alphaNumeric = regexp.MustCompile(`^[a-zA-Z0-9_\-\.]+$`).MatchString

// User is a stored account.
type User struct {
	ID       int    `pg:"type:serial"`
	Username string `pg:"type:varchar(255),unique"`
	// Password is the scrypt-derived key produced by calculateHash.
	Password []byte
	// Salt is the random per-user salt from genSalt, fed to calculateHash.
	Salt []byte
	// Role is "" for accounts created by AddUser and "admin" for accounts
	// created or promoted by SetAdminUser.
	Role string `pg:"type:varchar(255)"`
	// LastLogin is set to CURRENT_TIMESTAMP on each successful ValidPassword.
	LastLogin time.Time
}

// AddUser creates a regular account (empty role) named name with password pass.
//
// The name must satisfy validUserName and must not already exist (compared
// case-insensitively). Returned errors are user-facing messages; underlying
// database and hashing errors are logged instead of returned.
//
// NOTE(review): the existence check and the insert are not atomic, and the
// unique index on username is case-sensitive, so concurrent sign-ups that
// differ only in case could both succeed.
func (db *pgDB) AddUser(name string, pass string) error {
	if !validUserName(name) {
		return errors.New("Invalid user name. Username needs to have at least 3 characters and can only be letters, numbers, '-', '_' and '.'.")
	}

	num, err := db.sql.Model(&User{}).Where("lower(username) = lower(?)", name).Count()
	if err != nil {
		log.Error("Error on database checking user ", name, ": ", err)
		return errors.New("An error happen on the database")
	}
	if num != 0 {
		return errors.New("User name already exist")
	}

	hpass, salt, err := hashPass(pass)
	if err != nil {
		log.Error("Error hashing password: ", err)
		return errors.New("An error happen storing the password")
	}
	return db.AddRawUser(name, hpass, salt, "")
}

// SetAdminUser makes sure an account named name exists with role "admin" and
// password pass.
//
// If the account already exists (case-insensitive match), it is promoted and
// its password replaced; otherwise a new admin account is created. Admin names
// are validated with validAdminUserName, which, unlike validUserName, allows
// reserved names such as "admin".
func (db *pgDB) SetAdminUser(name string, pass string) error {
	if !validAdminUserName(name) {
		return errors.New("Invalid admin user name. Username needs to have at least 3 characters and can only be letters, numbers, '-', '_' and '.'.")
	}

	num, err := db.sql.Model(&User{}).Where("lower(username) = lower(?)", name).Count()
	if err != nil {
		log.Error("Error on database checking user ", name, ": ", err)
		return errors.New("An error happen on the database")
	}

	if num != 0 {
		if err = db.SetRole(name, "admin"); err != nil {
			log.Error("Error updating admin user ", name, ": ", err)
			return errors.New("Error updating admin user")
		}
		return db.SetPassword(name, pass)
	}

	hpass, salt, err := hashPass(pass)
	if err != nil {
		log.Error("Error hashing password: ", err)
		return errors.New("An error happen storing the password")
	}
	return db.AddRawUser(name, hpass, salt, "admin")
}

// AddRawUser inserts a user row with an already-hashed password and its salt.
// No validation or duplicate checking is performed.
func (db *pgDB) AddRawUser(name string, hpass []byte, salt []byte, role string) error {
	u := User{
		Username: name,
		Password: hpass,
		Salt:     salt,
		Role:     role,
	}
	_, err := db.sql.Model(&u).Insert()
	return err
}

// GetRole returns the role of the user whose name matches name
// case-insensitively, along with any error from the lookup.
func (db *pgDB) GetRole(name string) (string, error) {
	var u User
	err := db.sql.Model(&u).Where("lower(username) = lower(?)", name).Select()
	return u.Role, err
}

// SetRole sets the role of the user(s) whose name matches name
// case-insensitively.
func (db *pgDB) SetRole(name, role string) error {
	_, err := db.sql.Model(&User{}).
		Set("role = ?", role).
		Where("lower(username) = lower(?)", name).
		Update()
	return err
}

// ValidPassword reports whether pass is the correct password for the user
// named name (case-insensitive). On success it also records the login time;
// a failure to do so is logged but does not affect the result.
func (db *pgDB) ValidPassword(name string, pass string) bool {
	var u User
	err := db.sql.Model(&u).Where("lower(username) = lower(?)", name).Select()
	if err != nil {
		return false
	}

	hash, err := calculateHash(pass, u.Salt)
	if err != nil {
		return false
	}
	// Constant-time comparison so the check does not leak, through timing,
	// how many leading bytes of the stored hash matched.
	if subtle.ConstantTimeCompare(u.Password, hash) != 1 {
		return false
	}

	_, err = db.sql.Model(&User{}).
		Set("last_login = CURRENT_TIMESTAMP").
		Where("id = ?", u.ID).
		Update()
	if err != nil {
		log.Error("Error updating last login for ", u.Username, ": ", err)
	}
	return true
}

// SetPassword replaces the password of the user(s) matching name
// case-insensitively, generating a fresh salt.
func (db *pgDB) SetPassword(name string, pass string) error {
	hash, salt, err := hashPass(pass)
	if err != nil {
		return err
	}
	_, err = db.sql.Model(&User{}).
		Set("password = ?, salt = ?", hash, salt).
		Where("lower(username) = lower(?)", name).
		Update()
	return err
}

// ListUsers returns every stored user.
func (db *pgDB) ListUsers() ([]User, error) {
	var users []User
	err := db.sql.Model(&users).Select()
	return users, err
}

// getUser loads the user whose name matches name case-insensitively.
func (db *pgDB) getUser(name string) (User, error) {
	var user User
	err := db.sql.Model(&user).
		Where("lower(username) = lower(?)", name).
		Select()
	return user, err
}

// validAdminUserName reports whether name is at least 3 bytes long and
// contains only the characters accepted by alphaNumeric.
func validAdminUserName(name string) bool {
	if len(name) < 3 {
		return false
	}
	return alphaNumeric(name)
}

// validUserName is validAdminUserName plus a ban on reserved names.
// The reserved-name check ignores case because every lookup in this file is
// case-insensitive; otherwise "Admin" or "ROOT" could be registered.
func validUserName(name string) bool {
	if !validAdminUserName(name) {
		return false
	}
	switch strings.ToLower(name) {
	case "admin", "webmaster", "postmaster", "info", "root", "news",
		"trantor", "librarian", "library", "imperial":
		return false
	}
	return true
}

// hashPass generates a new random salt and returns the scrypt hash of pass
// under that salt.
func hashPass(pass string) (hash []byte, salt []byte, err error) {
	salt, err = genSalt()
	if err != nil {
		return nil, nil, err
	}
	hash, err = calculateHash(pass, salt)
	return hash, salt, err
}

// genSalt returns 64 bytes from crypto/rand.
func genSalt() ([]byte, error) {
	const saltLen = 64

	b := make([]byte, saltLen)
	_, err := rand.Read(b)
	return b, err
}

// calculateHash derives a 32-byte scrypt key from pass and salt
// (N=16384, r=8, p=1).
func calculateHash(pass string, salt []byte) ([]byte, error) {
	const (
		N      = 16384
		r      = 8
		p      = 1
		keyLen = 32
	)

	return scrypt.Key([]byte(pass), salt, N, r, p, keyLen)
}