diff --git a/lib/cover.go b/lib/cover.go index 053a8c2..be8fad8 100644 --- a/lib/cover.go +++ b/lib/cover.go @@ -68,7 +68,7 @@ func coverHandler(h handler) { } } -func GetCover(e *epubgo.Epub, id string, store *storage.Store) bool { +func GetCover(e *epubgo.Epub, id string, store storage.Store) bool { if coverFromMetadata(e, id, store) { return true } @@ -121,7 +121,7 @@ func GetCover(e *epubgo.Epub, id string, store *storage.Store) bool { return false } -func coverFromMetadata(e *epubgo.Epub, id string, store *storage.Store) bool { +func coverFromMetadata(e *epubgo.Epub, id string, store storage.Store) bool { metaList, _ := e.MetadataAttr("meta") for _, meta := range metaList { if meta["name"] == "cover" { @@ -135,7 +135,7 @@ func coverFromMetadata(e *epubgo.Epub, id string, store *storage.Store) bool { return false } -func searchCommonCoverNames(e *epubgo.Epub, id string, store *storage.Store) bool { +func searchCommonCoverNames(e *epubgo.Epub, id string, store storage.Store) bool { for _, p := range []string{"cover.jpg", "Images/cover.jpg", "images/cover.jpg", "cover.jpeg", "cover1.jpg", "cover1.jpeg"} { img, err := e.OpenFile(p) if err == nil { @@ -146,7 +146,7 @@ func searchCommonCoverNames(e *epubgo.Epub, id string, store *storage.Store) boo return false } -func storeImg(img io.Reader, id string, store *storage.Store) bool { +func storeImg(img io.Reader, id string, store storage.Store) bool { /* open the files */ fBig, err := store.Create(id, coverFile) if err != nil { diff --git a/lib/database/books_test.go b/lib/database/books_test.go index 2c8bc81..c0246a9 100644 --- a/lib/database/books_test.go +++ b/lib/database/books_test.go @@ -10,7 +10,7 @@ var book = map[string]interface{}{ func TestAddBook(t *testing.T) { db := Init(test_host, test_coll) - defer db.del() + defer del(db) tAddBook(t, db) @@ -31,7 +31,7 @@ func TestAddBook(t *testing.T) { func TestActiveBook(t *testing.T) { db := Init(test_host, test_coll) - defer db.del() + defer del(db) tAddBook(t, db) books, _, _ := db.GetNewBooks("", 1, 0) @@ -53,7 +53,7 @@ func TestActiveBook(t *testing.T) { func TestFlag(t *testing.T) { db := Init(test_host, test_coll) - defer db.del() + defer del(db) tAddBook(t, db) id, _ := book["id"].(string) @@ -97,7 +97,7 @@ func TestFlag(t *testing.T) { } } -func tAddBook(t *testing.T, db *DB) { +func tAddBook(t *testing.T, db DB) { err := db.AddBook(book) if err != nil { t.Error("db.AddBook(", book, ") return an error:", err) diff --git a/lib/database/database.go b/lib/database/database.go index d0698e0..600fc85 100644 --- a/lib/database/database.go +++ b/lib/database/database.go @@ -3,22 +3,42 @@ package database import ( log "github.com/cihub/seelog" - "errors" "os" "gopkg.in/mgo.v2" - "gopkg.in/mgo.v2/bson" ) -const ( - visited_coll = "visited" - downloaded_coll = "downloaded" - tags_coll = "tags" -) - -type DB struct { - session *mgo.Session - name string +type DB interface { + Close() + Copy() DB + AddBook(book map[string]interface{}) error + GetBooks(query string, length int, start int) (books []Book, num int, err error) + GetBooksIter() Iter + GetNewBooks(query string, length int, start int) (books []Book, num int, err error) + GetBookId(id string) (Book, error) + DeleteBook(id string) error + UpdateBook(id string, data map[string]interface{}) error + FlagBadQuality(id string, user string) error + ActiveBook(id string) error + IsBookActive(id string) bool + User(name string) *User + AddUser(name string, pass string) error + AddNews(text string) error + GetNews(num int, days int) (news []News, err error) + AddStats(stats interface{}) error + GetVisitedBooks() (books []Book, err error) + UpdateMostVisited() error + GetDownloadedBooks() (books []Book, err error) + UpdateDownloadedBooks() error + GetTags() ([]string, error) + UpdateTags() error + GetVisits(visitType VisitType) ([]Visits, error) + UpdateHourVisits() error + UpdateDayVisits() error + UpdateMonthVisits() error + UpdateHourDownloads() error + UpdateDayDownloads() error + UpdateMonthDownloads() error } type Iter interface { @@ -26,9 +46,9 @@ type Iter interface { Next(interface{}) bool } -func Init(host string, name string) *DB { +func Init(host string, name string) DB { var err error - db := new(DB) + db := new(mgoDB) db.session, err = mgo.Dial(host) if err != nil { log.Critical(err) @@ -39,230 +59,6 @@ func Init(host string, name string) *DB { return db } -func (db *DB) initIndexes() { - dbCopy := db.session.Copy() - booksColl := dbCopy.DB(db.name).C(books_coll) - go indexBooks(booksColl) - statsColl := dbCopy.DB(db.name).C(stats_coll) - go indexStats(statsColl) - newsColl := dbCopy.DB(db.name).C(news_coll) - go indexNews(newsColl) -} - -func (db *DB) Close() { - db.session.Close() -} - -func (db *DB) Copy() *DB { - dbCopy := new(DB) - dbCopy.session = db.session.Copy() - dbCopy.name = db.name - return dbCopy -} - -func (db *DB) AddBook(book map[string]interface{}) error { - booksColl := db.session.DB(db.name).C(books_coll) - return addBook(booksColl, book) -} - -func (db *DB) GetBooks(query string, length int, start int) (books []Book, num int, err error) { - booksColl := db.session.DB(db.name).C(books_coll) - return getBooks(booksColl, query, length, start) -} - -func (db *DB) GetBooksIter() Iter { - booksColl := db.session.DB(db.name).C(books_coll) - return getBooksIter(booksColl) -} - -func (db *DB) GetNewBooks(query string, length int, start int) (books []Book, num int, err error) { - booksColl := db.session.DB(db.name).C(books_coll) - return getNewBooks(booksColl, query, length, start) -} - -func (db *DB) GetBookId(id string) (Book, error) { - booksColl := db.session.DB(db.name).C(books_coll) - return getBookId(booksColl, id) -} - -func (db *DB) DeleteBook(id string) error { - booksColl := db.session.DB(db.name).C(books_coll) - return deleteBook(booksColl, id) -} - -func (db *DB) UpdateBook(id string, data map[string]interface{}) error { - booksColl := db.session.DB(db.name).C(books_coll) - return updateBook(booksColl, id, data) -} - -func (db *DB) FlagBadQuality(id string, user string) error { - booksColl := db.session.DB(db.name).C(books_coll) - return flagBadQuality(booksColl, id, user) -} - -func (db *DB) ActiveBook(id string) error { - booksColl := db.session.DB(db.name).C(books_coll) - return activeBook(booksColl, id) -} - -func (db *DB) IsBookActive(id string) bool { - booksColl := db.session.DB(db.name).C(books_coll) - return isBookActive(booksColl, id) -} - -func (db *DB) User(name string) *User { - userColl := db.session.DB(db.name).C(user_coll) - return getUser(userColl, name) -} - -func (db *DB) AddUser(name string, pass string) error { - userColl := db.session.DB(db.name).C(user_coll) - return addUser(userColl, name, pass) -} - -func (db *DB) AddNews(text string) error { - newsColl := db.session.DB(db.name).C(news_coll) - return addNews(newsColl, text) -} - -func (db *DB) GetNews(num int, days int) (news []News, err error) { - newsColl := db.session.DB(db.name).C(news_coll) - return getNews(newsColl, num, days) -} - -// TODO: split code in files -func (db *DB) AddStats(stats interface{}) error { - statsColl := db.session.DB(db.name).C(stats_coll) - return statsColl.Insert(stats) -} - -/* Get the most visited books - */ -func (db *DB) GetVisitedBooks() (books []Book, err error) { - visitedColl := db.session.DB(db.name).C(visited_coll) - bookId, err := GetBooksVisited(visitedColl) - if err != nil { - return nil, err - } - - books = make([]Book, len(bookId)) - for i, id := range bookId { - booksColl := db.session.DB(db.name).C(books_coll) - booksColl.Find(bson.M{"_id": id}).One(&books[i]) - books[i].Id = bson.ObjectId(books[i].Id).Hex() - } - return -} - -func (db *DB) UpdateMostVisited() error { - var u dbUpdate - u.src = db.session.DB(db.name).C(stats_coll) - u.dst = db.session.DB(db.name).C(visited_coll) - return u.UpdateMostBooks("book") -} - -/* Get the most downloaded books - */ -func (db *DB) GetDownloadedBooks() (books []Book, err error) { - downloadedColl := db.session.DB(db.name).C(downloaded_coll) - bookId, err := GetBooksVisited(downloadedColl) - if err != nil { - return nil, err - } - - books = make([]Book, len(bookId)) - for i, id := range bookId { - booksColl := db.session.DB(db.name).C(books_coll) - booksColl.Find(bson.M{"_id": id}).One(&books[i]) - books[i].Id = bson.ObjectId(books[i].Id).Hex() - } - return -} - -func (db *DB) UpdateDownloadedBooks() error { - var u dbUpdate - u.src = db.session.DB(db.name).C(stats_coll) - u.dst = db.session.DB(db.name).C(downloaded_coll) - return u.UpdateMostBooks("download") -} - -func (db *DB) GetTags() ([]string, error) { - tagsColl := db.session.DB(db.name).C(tags_coll) - return GetTags(tagsColl) -} - -func (db *DB) UpdateTags() error { - var u dbUpdate - u.src = db.session.DB(db.name).C(books_coll) - u.dst = db.session.DB(db.name).C(tags_coll) - return u.UpdateTags() -} - -func (db *DB) GetVisits(visitType VisitType) ([]Visits, error) { - var coll *mgo.Collection - switch visitType { - case Hourly_visits: - coll = db.session.DB(db.name).C(hourly_visits_coll) - case Daily_visits: - coll = db.session.DB(db.name).C(daily_visits_coll) - case Monthly_visits: - coll = db.session.DB(db.name).C(monthly_visits_coll) - case Hourly_downloads: - coll = db.session.DB(db.name).C(hourly_downloads_coll) - case Daily_downloads: - coll = db.session.DB(db.name).C(daily_downloads_coll) - case Monthly_downloads: - coll = db.session.DB(db.name).C(monthly_downloads_coll) - default: - return nil, errors.New("Not valid VisitType") - } - return GetVisits(coll) -} - -func (db *DB) UpdateHourVisits() error { - var u dbUpdate - u.src = db.session.DB(db.name).C(stats_coll) - u.dst = db.session.DB(db.name).C(hourly_visits_coll) - return u.UpdateHourVisits(false) -} - -func (db *DB) UpdateDayVisits() error { - var u dbUpdate - u.src = db.session.DB(db.name).C(stats_coll) - u.dst = db.session.DB(db.name).C(daily_visits_coll) - return u.UpdateDayVisits(false) -} - -func (db *DB) UpdateMonthVisits() error { - var u dbUpdate - u.src = db.session.DB(db.name).C(stats_coll) - u.dst = db.session.DB(db.name).C(monthly_visits_coll) - return u.UpdateMonthVisits(false) -} - -func (db *DB) UpdateHourDownloads() error { - var u dbUpdate - u.src = db.session.DB(db.name).C(stats_coll) - u.dst = db.session.DB(db.name).C(hourly_downloads_coll) - return u.UpdateHourVisits(true) -} - -func (db *DB) UpdateDayDownloads() error { - var u dbUpdate - u.src = db.session.DB(db.name).C(stats_coll) - u.dst = db.session.DB(db.name).C(daily_downloads_coll) - return u.UpdateDayVisits(true) -} - -func (db *DB) UpdateMonthDownloads() error { - var u dbUpdate - u.src = db.session.DB(db.name).C(stats_coll) - u.dst = db.session.DB(db.name).C(monthly_downloads_coll) - return u.UpdateMonthVisits(true) -} - -// function defined for the tests -func (db *DB) del() { - defer db.Close() - db.session.DB(db.name).DropDatabase() +func RO(db DB) DB { + return &roDB{db} } diff --git a/lib/database/database_test.go b/lib/database/database_test.go index f62d51e..cfc2a43 100644 --- a/lib/database/database_test.go +++ b/lib/database/database_test.go @@ -1,6 +1,10 @@ package database -import "testing" +import ( + "testing" + + mgo "gopkg.in/mgo.v2" +) const ( test_coll = "test_trantor" @@ -12,29 +16,9 @@ func TestInit(t *testing.T) { defer db.Close() } -func TestCopy(t *testing.T) { - db := Init(test_host, test_coll) - defer db.del() - - db2 := db.Copy() - - if db.name != db2.name { - t.Errorf("Names don't match") - } - names1, err := db.session.DatabaseNames() - if err != nil { - t.Errorf("Error on db1: ", err) - } - names2, err := db2.session.DatabaseNames() - if err != nil { - t.Errorf("Error on db1: ", err) - } - if len(names1) != len(names2) { - t.Errorf("len(names) don't match") - } - for i, _ := range names1 { - if names1[i] != names2[i] { - t.Errorf("Names don't match") - } - } +func del(db DB) { + db.Close() + session, _ := mgo.Dial(test_host) + defer session.Close() + session.DB(test_coll).DropDatabase() } diff --git a/lib/database/mgo.go b/lib/database/mgo.go new file mode 100644 index 0000000..25fe6f1 --- /dev/null +++ b/lib/database/mgo.go @@ -0,0 +1,241 @@ +package database + +import ( + "errors" + + mgo "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +const ( + visited_coll = "visited" + downloaded_coll = "downloaded" + tags_coll = "tags" +) + +type mgoDB struct { + session *mgo.Session + name string +} + +func (db *mgoDB) initIndexes() { + dbCopy := db.session.Copy() + booksColl := dbCopy.DB(db.name).C(books_coll) + go indexBooks(booksColl) + statsColl := dbCopy.DB(db.name).C(stats_coll) + go indexStats(statsColl) + newsColl := dbCopy.DB(db.name).C(news_coll) + go indexNews(newsColl) +} + +func (db *mgoDB) Close() { + db.session.Close() +} + +func (db *mgoDB) Copy() DB { + dbCopy := new(mgoDB) + dbCopy.session = db.session.Copy() + dbCopy.name = db.name + return dbCopy +} + +func (db *mgoDB) AddBook(book map[string]interface{}) error { + booksColl := db.session.DB(db.name).C(books_coll) + return addBook(booksColl, book) +} + +func (db *mgoDB) GetBooks(query string, length int, start int) (books []Book, num int, err error) { + booksColl := db.session.DB(db.name).C(books_coll) + return getBooks(booksColl, query, length, start) +} + +func (db *mgoDB) GetBooksIter() Iter { + booksColl := db.session.DB(db.name).C(books_coll) + return getBooksIter(booksColl) +} + +func (db *mgoDB) GetNewBooks(query string, length int, start int) (books []Book, num int, err error) { + booksColl := db.session.DB(db.name).C(books_coll) + return getNewBooks(booksColl, query, length, start) +} + +func (db *mgoDB) GetBookId(id string) (Book, error) { + booksColl := db.session.DB(db.name).C(books_coll) + return getBookId(booksColl, id) +} + +func (db *mgoDB) DeleteBook(id string) error { + booksColl := db.session.DB(db.name).C(books_coll) + return deleteBook(booksColl, id) +} + +func (db *mgoDB) UpdateBook(id string, data map[string]interface{}) error { + booksColl := db.session.DB(db.name).C(books_coll) + return updateBook(booksColl, id, data) +} + +func (db *mgoDB) FlagBadQuality(id string, user string) error { + booksColl := db.session.DB(db.name).C(books_coll) + return flagBadQuality(booksColl, id, user) +} + +func (db *mgoDB) ActiveBook(id string) error { + booksColl := db.session.DB(db.name).C(books_coll) + return activeBook(booksColl, id) +} + +func (db *mgoDB) IsBookActive(id string) bool { + booksColl := db.session.DB(db.name).C(books_coll) + return isBookActive(booksColl, id) +} + +func (db *mgoDB) User(name string) *User { + userColl := db.session.DB(db.name).C(user_coll) + return getUser(userColl, name) +} + +func (db *mgoDB) AddUser(name string, pass string) error { + userColl := db.session.DB(db.name).C(user_coll) + return addUser(userColl, name, pass) +} + +func (db *mgoDB) AddNews(text string) error { + newsColl := db.session.DB(db.name).C(news_coll) + return addNews(newsColl, text) +} + +func (db *mgoDB) GetNews(num int, days int) (news []News, err error) { + newsColl := db.session.DB(db.name).C(news_coll) + return getNews(newsColl, num, days) +} + +// TODO: split code in files +func (db *mgoDB) AddStats(stats interface{}) error { + statsColl := db.session.DB(db.name).C(stats_coll) + return statsColl.Insert(stats) +} + +/* Get the most visited books + */ +func (db *mgoDB) GetVisitedBooks() (books []Book, err error) { + visitedColl := db.session.DB(db.name).C(visited_coll) + bookId, err := GetBooksVisited(visitedColl) + if err != nil { + return nil, err + } + + books = make([]Book, len(bookId)) + for i, id := range bookId { + booksColl := db.session.DB(db.name).C(books_coll) + booksColl.Find(bson.M{"_id": id}).One(&books[i]) + books[i].Id = bson.ObjectId(books[i].Id).Hex() + } + return +} + +func (db *mgoDB) UpdateMostVisited() error { + var u dbUpdate + u.src = db.session.DB(db.name).C(stats_coll) + u.dst = db.session.DB(db.name).C(visited_coll) + return u.UpdateMostBooks("book") +} + +/* Get the most downloaded books + */ +func (db *mgoDB) GetDownloadedBooks() (books []Book, err error) { + downloadedColl := db.session.DB(db.name).C(downloaded_coll) + bookId, err := GetBooksVisited(downloadedColl) + if err != nil { + return nil, err + } + + books = make([]Book, len(bookId)) + for i, id := range bookId { + booksColl := db.session.DB(db.name).C(books_coll) + booksColl.Find(bson.M{"_id": id}).One(&books[i]) + books[i].Id = bson.ObjectId(books[i].Id).Hex() + } + return +} + +func (db *mgoDB) UpdateDownloadedBooks() error { + var u dbUpdate + u.src = db.session.DB(db.name).C(stats_coll) + u.dst = db.session.DB(db.name).C(downloaded_coll) + return u.UpdateMostBooks("download") +} + +func (db *mgoDB) GetTags() ([]string, error) { + tagsColl := db.session.DB(db.name).C(tags_coll) + return GetTags(tagsColl) +} + +func (db *mgoDB) UpdateTags() error { + var u dbUpdate + u.src = db.session.DB(db.name).C(books_coll) + u.dst = db.session.DB(db.name).C(tags_coll) + return u.UpdateTags() +} + +func (db *mgoDB) GetVisits(visitType VisitType) ([]Visits, error) { + var coll *mgo.Collection + switch visitType { + case Hourly_visits: + coll = db.session.DB(db.name).C(hourly_visits_coll) + case Daily_visits: + coll = db.session.DB(db.name).C(daily_visits_coll) + case Monthly_visits: + coll = db.session.DB(db.name).C(monthly_visits_coll) + case Hourly_downloads: + coll = db.session.DB(db.name).C(hourly_downloads_coll) + case Daily_downloads: + coll = db.session.DB(db.name).C(daily_downloads_coll) + case Monthly_downloads: + coll = db.session.DB(db.name).C(monthly_downloads_coll) + default: + return nil, errors.New("Not valid VisitType") + } + return GetVisits(coll) +} + +func (db *mgoDB) UpdateHourVisits() error { + var u dbUpdate + u.src = db.session.DB(db.name).C(stats_coll) + u.dst = db.session.DB(db.name).C(hourly_visits_coll) + return u.UpdateHourVisits(false) +} + +func (db *mgoDB) UpdateDayVisits() error { + var u dbUpdate + u.src = db.session.DB(db.name).C(stats_coll) + u.dst = db.session.DB(db.name).C(daily_visits_coll) + return u.UpdateDayVisits(false) +} + +func (db *mgoDB) UpdateMonthVisits() error { + var u dbUpdate + u.src = db.session.DB(db.name).C(stats_coll) + u.dst = db.session.DB(db.name).C(monthly_visits_coll) + return u.UpdateMonthVisits(false) +} + +func (db *mgoDB) UpdateHourDownloads() error { + var u dbUpdate + u.src = db.session.DB(db.name).C(stats_coll) + u.dst = db.session.DB(db.name).C(hourly_downloads_coll) + return u.UpdateHourVisits(true) +} + +func (db *mgoDB) UpdateDayDownloads() error { + var u dbUpdate + u.src = db.session.DB(db.name).C(stats_coll) + u.dst = db.session.DB(db.name).C(daily_downloads_coll) + return u.UpdateDayVisits(true) +} + +func (db *mgoDB) UpdateMonthDownloads() error { + var u dbUpdate + u.src = db.session.DB(db.name).C(stats_coll) + u.dst = db.session.DB(db.name).C(monthly_downloads_coll) + return u.UpdateMonthVisits(true) +} diff --git a/lib/database/news_test.go b/lib/database/news_test.go index 0c91034..ad74388 100644 --- a/lib/database/news_test.go +++ b/lib/database/news_test.go @@ -6,7 +6,7 @@ func TestNews(t *testing.T) { const text = "Some news text" db := Init(test_host, test_coll) - defer db.del() + defer del(db) err := db.AddNews(text) if err != nil { diff --git a/lib/database/ro.go b/lib/database/ro.go new file mode 100644 index 0000000..c728717 --- /dev/null +++ b/lib/database/ro.go @@ -0,0 +1,129 @@ +package database + +import ( + "errors" +) + +type roDB struct { + db DB +} + +func (db *roDB) Close() { + db.db.Close() +} + +func (db *roDB) Copy() DB { + return &roDB{db.db.Copy()} +} + +func (db *roDB) AddBook(book map[string]interface{}) error { + return errors.New("RO database") +} + +func (db *roDB) GetBooks(query string, length int, start int) (books []Book, num int, err error) { + return db.db.GetBooks(query, length, start) +} + +func (db *roDB) GetBooksIter() Iter { + return db.db.GetBooksIter() +} + +func (db *roDB) GetNewBooks(query string, length int, start int) (books []Book, num int, err error) { + return db.db.GetNewBooks(query, length, start) +} + +func (db *roDB) GetBookId(id string) (Book, error) { + return db.db.GetBookId(id) +} + +func (db *roDB) DeleteBook(id string) error { + return errors.New("RO database") +} + +func (db *roDB) UpdateBook(id string, data map[string]interface{}) error { + return errors.New("RO database") +} + +func (db *roDB) FlagBadQuality(id string, user string) error { + return errors.New("RO database") +} + +func (db *roDB) ActiveBook(id string) error { + return errors.New("RO database") +} + +func (db *roDB) IsBookActive(id string) bool { + return db.db.IsBookActive(id) +} + +func (db *roDB) User(name string) *User { + return db.db.User(name) +} + +func (db *roDB) AddUser(name string, pass string) error { + return errors.New("RO database") +} + +func (db *roDB) AddNews(text string) error { + return errors.New("RO database") +} + +func (db *roDB) GetNews(num int, days int) (news []News, err error) { + return db.db.GetNews(num, days) +} + +func (db *roDB) AddStats(stats interface{}) error { + return nil +} + +func (db *roDB) GetVisitedBooks() (books []Book, err error) { + return db.db.GetVisitedBooks() +} + +func (db *roDB) UpdateMostVisited() error { + return errors.New("RO database") +} + +func (db *roDB) GetDownloadedBooks() (books []Book, err error) { + return db.db.GetDownloadedBooks() +} + +func (db *roDB) UpdateDownloadedBooks() error { + return errors.New("RO database") +} + +func (db *roDB) GetTags() ([]string, error) { + return db.db.GetTags() +} + +func (db *roDB) UpdateTags() error { + return errors.New("RO database") +} + +func (db *roDB) GetVisits(visitType VisitType) ([]Visits, error) { + return db.db.GetVisits(visitType) +} + +func (db *roDB) UpdateHourVisits() error { + return errors.New("RO database") +} + +func (db *roDB) UpdateDayVisits() error { + return errors.New("RO database") +} + +func (db *roDB) UpdateMonthVisits() error { + return errors.New("RO database") +} + +func (db *roDB) UpdateHourDownloads() error { + return errors.New("RO database") +} + +func (db *roDB) UpdateDayDownloads() error { + return errors.New("RO database") +} + +func (db *roDB) UpdateMonthDownloads() error { + return errors.New("RO database") +} diff --git a/lib/database/users_test.go b/lib/database/users_test.go index 79ca366..cea542c 100644 --- a/lib/database/users_test.go +++ b/lib/database/users_test.go @@ -8,7 +8,7 @@ const ( func TestUserEmpty(t *testing.T) { db := Init(test_host, test_coll) - defer db.del() + defer del(db) if db.User("").Valid("") { t.Errorf("user.Valid() with an empty password return true") @@ -17,7 +17,7 @@ func TestUserEmpty(t *testing.T) { func TestAddUser(t *testing.T) { db := Init(test_host, test_coll) - defer db.del() + defer del(db) tAddUser(t, db) if !db.User(name).Valid(pass) { @@ -27,7 +27,7 @@ func TestAddUser(t *testing.T) { func TestEmptyUsername(t *testing.T) { db := Init(test_host, test_coll) - defer db.del() + defer del(db) tAddUser(t, db) if db.User("").Valid(pass) { @@ -35,7 +35,7 @@ func TestEmptyUsername(t *testing.T) { } } -func tAddUser(t *testing.T, db *DB) { +func tAddUser(t *testing.T, db DB) { err := db.AddUser(name, pass) if err != nil { t.Errorf("db.Adduser(", name, ", ", pass, ") return an error: ", err) diff --git a/lib/news.go b/lib/news.go index ba2973b..6ef6a80 100644 --- a/lib/news.go +++ b/lib/news.go @@ -60,7 +60,7 @@ func postNewsHandler(h handler) { http.Redirect(h.w, h.r, "/news/", http.StatusFound) } -func getNews(num int, days int, db *database.DB) []newsEntry { +func getNews(num int, days int, db database.DB) []newsEntry { dbnews, _ := db.GetNews(num, days) news := make([]newsEntry, len(dbnews)) for i, n := range dbnews { diff --git a/lib/session.go b/lib/session.go index 802ed71..31cd08d 100644 --- a/lib/session.go +++ b/lib/session.go @@ -23,7 +23,7 @@ type Session struct { S *sessions.Session } -func GetSession(r *http.Request, db *database.DB) (s *Session) { +func GetSession(r *http.Request, db database.DB) (s *Session) { s = new(Session) var err error s.S, err = sesStore.Get(r, "session") diff --git a/lib/stats.go b/lib/stats.go index f4fa359..b1b7af5 100644 --- a/lib/stats.go +++ b/lib/stats.go @@ -22,27 +22,30 @@ type handler struct { w http.ResponseWriter r *http.Request sess *Session - db *database.DB - store *storage.Store + db database.DB + store storage.Store template *Template hostname string + ro bool } type StatsGatherer struct { - db *database.DB - store *storage.Store + db database.DB + store storage.Store template *Template hostname string channel chan statsRequest + ro bool } -func InitStats(database *database.DB, store *storage.Store, hostname string, template *Template) *StatsGatherer { +func InitStats(database database.DB, store storage.Store, hostname string, template *Template, ro bool) *StatsGatherer { sg := StatsGatherer{ channel: make(chan statsRequest, statsChanSize), db: database, store: store, template: template, hostname: hostname, + ro: ro, } go sg.worker() @@ -62,6 +65,7 @@ func (sg StatsGatherer) Gather(function func(handler)) func(http.ResponseWriter, w: w, r: r, sess: GetSession(r, db), + ro: sg.ro, } defer h.db.Close() @@ -136,7 +140,7 @@ func monthlyLabel(date time.Time) string { return date.Month().String() } -func getVisits(funcLabel func(time.Time) string, db *database.DB, visitType database.VisitType) []visitData { +func getVisits(funcLabel func(time.Time) string, db database.DB, visitType database.VisitType) []visitData { var visits []visitData visit, err := db.GetVisits(visitType) diff --git a/lib/storage/fs.go b/lib/storage/fs.go new file mode 100644 index 0000000..d0d4ff3 --- /dev/null +++ b/lib/storage/fs.go @@ -0,0 +1,42 @@ +package storage + +import ( + p "path" + + "io" + "os" +) + +type fsStore struct { + path string +} + +func (st *fsStore) Create(id string, name string) (io.WriteCloser, error) { + path := idPath(st.path, id) + err := os.MkdirAll(path, os.ModePerm) + if err != nil { + return nil, err + } + + return os.Create(p.Join(path, name)) +} + +func (st *fsStore) Store(id string, file io.Reader, name string) (size int64, err error) { + dest, err := st.Create(id, name) + if err != nil { + return 0, err + } + defer dest.Close() + + return io.Copy(dest, file) +} + +func (st *fsStore) Get(id string, name string) (File, error) { + path := idPath(st.path, id) + return os.Open(p.Join(path, name)) +} + +func (st *fsStore) Delete(id string) error { + path := idPath(st.path, id) + return os.RemoveAll(path) +} diff --git a/lib/storage/ro.go b/lib/storage/ro.go new file mode 100644 index 0000000..278cf47 --- /dev/null +++ b/lib/storage/ro.go @@ -0,0 +1,26 @@ +package storage + +import ( + "errors" + "io" +) + +type roStore struct { + store Store +} + +func (st *roStore) Create(id string, name string) (io.WriteCloser, error) { + return nil, errors.New("Can't create, RO storage") +} + +func (st *roStore) Store(id string, file io.Reader, name string) (size int64, err error) { + return 0, errors.New("Can't store, RO storage") +} + +func (st *roStore) Get(id string, name string) (File, error) { + return st.store.Get(id, name) +} + +func (st *roStore) Delete(id string) error { + return errors.New("Can't delete, RO storage") +} diff --git a/lib/storage/storage.go b/lib/storage/storage.go index 531d39d..2ccbe95 100644 --- a/lib/storage/storage.go +++ b/lib/storage/storage.go @@ -1,14 +1,15 @@ package storage import ( - p "path" - "io" "os" ) -type Store struct { - path string +type Store interface { + Create(id string, name string) (io.WriteCloser, error) + Store(id string, file io.Reader, name string) (size int64, err error) + Get(id string, name string) (File, error) + Delete(id string) error } type File interface { @@ -17,8 +18,8 @@ type File interface { Stat() (fi os.FileInfo, err error) } -func Init(path string) (*Store, error) { - st := new(Store) +func Init(path string) (Store, error) { + st := new(fsStore) st.path = path _, err := os.Stat(path) @@ -28,36 +29,6 @@ func Init(path string) (*Store, error) { return st, err } -func (st *Store) Create(id string, name string) (io.WriteCloser, error) { - path := idPath(st.path, id) - err := os.MkdirAll(path, os.ModePerm) - if err != nil { - return nil, err - } - - return os.Create(p.Join(path, name)) -} - -func (st *Store) Store(id string, file io.Reader, name string) (size int64, err error) { - dest, err := st.Create(id, name) - if err != nil { - return 0, err - } - defer dest.Close() - - return io.Copy(dest, file) -} - -func (st *Store) Get(id string, name string) (File, error) { - path := idPath(st.path, id) - return os.Open(p.Join(path, name)) -} - -func (st *Store) Delete(id string) error { - path := idPath(st.path, id) - return os.RemoveAll(path) -} - -func (st *Store) del() { - os.RemoveAll(st.path) +func RO(st Store) Store { + return &roStore{st} } diff --git a/lib/storage/storage_test.go b/lib/storage/storage_test.go index c8aa657..ba465a5 100644 --- a/lib/storage/storage_test.go +++ b/lib/storage/storage_test.go @@ -25,11 +25,11 @@ const ( ) func TestInit(t *testing.T) { - st, err := Init(test_path) + _, err := Init(test_path) if err != nil { t.Fatal("An error ocurred initializing the store =>", err) } - defer st.del() + defer del() info, err := os.Stat(test_path) if err != nil { @@ -50,7 +50,7 @@ func TestInit(t *testing.T) { func TestStore(t *testing.T) { st, err := Init(test_path) - defer st.del() + defer del() _, err = st.Store(test_id, strings.NewReader(test_book), "epub") if err != nil { @@ -72,7 +72,7 @@ func TestStore(t *testing.T) { func TestCreate(t *testing.T) { st, err := Init(test_path) - defer st.del() + defer del() f, err := st.Create(test_id, "img") if err != nil { @@ -95,7 +95,7 @@ func TestCreate(t *testing.T) { func TestDelete(t *testing.T) { st, err := Init(test_path) - defer st.del() + defer del() _, err = st.Store(test_id, strings.NewReader(test_book), "epub") if err != nil { @@ -111,3 +111,7 @@ func TestDelete(t *testing.T) { t.Fatal("Retrieve book without error.") } } + +func del() { + os.RemoveAll(test_path) +} diff --git a/lib/tasker.go b/lib/tasker.go index 58e81d6..afb5eaf 100644 --- a/lib/tasker.go +++ b/lib/tasker.go @@ -21,7 +21,7 @@ const ( minutesUpdateLogger = 5 ) -func InitTasks(db *database.DB, loggerConfig string) { +func InitTasks(db database.DB, loggerConfig string) { updateLogger := func() error { return UpdateLogger(loggerConfig) } diff --git a/lib/trantor.go b/lib/trantor.go index c315c91..88f64c3 100644 --- a/lib/trantor.go +++ b/lib/trantor.go @@ -169,7 +169,7 @@ func UpdateLogger(loggerConfig string) error { return log.ReplaceLogger(logger) } -func InitRouter(db *database.DB, sg *StatsGatherer, assetsPath string) { +func InitRouter(db database.DB, sg *StatsGatherer, assetsPath string) { const idPattern = "[0-9a-zA-Z\\-\\_]{16}" r := mux.NewRouter() diff --git a/lib/upload.go b/lib/upload.go index 7948b2b..c9c8b84 100644 --- a/lib/upload.go +++ b/lib/upload.go @@ -19,7 +19,7 @@ const ( uploadChanSize = 100 ) -func InitUpload(database *database.DB, store *storage.Store) { +func InitUpload(database database.DB, store storage.Store) { uploadChannel = make(chan uploadRequest, uploadChanSize) go uploadWorker(database, store) } @@ -31,7 +31,7 @@ type uploadRequest struct { filename string } -func uploadWorker(database *database.DB, store *storage.Store) { +func uploadWorker(database database.DB, store storage.Store) { db := database.Copy() defer db.Close() @@ -40,7 +40,7 @@ func uploadWorker(database *database.DB, store *storage.Store) { } } -func processFile(req uploadRequest, db *database.DB, store *storage.Store) { +func processFile(req uploadRequest, db database.DB, store storage.Store) { defer req.file.Close() epub, err := openMultipartEpub(req.file) @@ -72,6 +72,12 @@ func processFile(req uploadRequest, db *database.DB, store *storage.Store) { } func uploadPostHandler(h handler) { + if h.ro { + h.sess.Notify("Upload failed!", "The library is in Read Only mode, no books can be uploaded", "error") + uploadHandler(h) + return + } + problem := false h.r.ParseMultipartForm(20000000) diff --git a/main.go b/main.go index dbfe980..11d3e32 100644 --- a/main.go +++ b/main.go @@ -21,6 +21,7 @@ func main() { assetsPath = flag.String("assets", ".", "Path of the assets (templates, css, js, img)") hostname = flag.String("hostname", "xfmro77i3lixucja.onion", "Hostname of the website") loggerConfig = flag.String("logger-conf", "logger.xml", "xml configuration of the logger") + ro = flag.Bool("ro", false, "read only mode") ) flag.Parse() @@ -40,8 +41,13 @@ func main() { os.Exit(1) } + if *ro { + store = storage.RO(store) + db = database.RO(db) + } + template := trantor.InitTemplate(*assetsPath) - sg := trantor.InitStats(db, store, *hostname, template) + sg := trantor.InitStats(db, store, *hostname, template, *ro) trantor.InitUpload(db, store) trantor.InitTasks(db, *loggerConfig)