package database

import (
	"io/ioutil"
	"testing"
)

// testDbInit connects to the "test_trantor" database, loads the schema from
// ../../createdb.sql and returns the DB together with a cleanup function.
//
// The cleanup function drops every table, index, function and trigger found
// in the "public" schema and then closes the connection; callers should defer
// it. If loading the schema fails, the cleanup is run before the test is
// aborted so the connection is released and a half-created schema is not left
// behind for the next run.
func testDbInit(t *testing.T) (DB, func()) {
	t.Helper()

	db, err := Init(Options{
		Name: "test_trantor",
		// TODO: can it be done with a local user?
		User:     "trantor",
		Password: "trantor",
	})
	if err != nil {
		t.Fatalf("Init() returned an error: %v", err)
	}

	// Schema setup and cleanup need raw SQL access through the *pgDB
	// implementation; fail clearly instead of dereferencing a nil pointer
	// if Init ever returns a different implementation.
	pgdb, ok := db.(*pgDB)
	if !ok {
		_ = db.Close() // best effort: the test is aborted right below
		t.Fatalf("Init() returned %T, want *pgDB", db)
	}

	cleanFn := func() {
		t.Helper()
		// Order matters: dropping tables with cascade also removes their
		// indexes and triggers, so the later passes only catch leftovers.
		// Every query returns identifiers already quoted for direct use in
		// a DROP statement.
		entities := []struct {
			name  string
			query string
		}{
			{"table", "select quote_ident(tablename) from pg_tables where schemaname = 'public'"},
			{"index", "select quote_ident(indexname) from pg_indexes where schemaname = 'public'"},
			{"function", `SELECT format('%I(%s)', p.proname, pg_get_function_identity_arguments(p.oid))
				FROM pg_catalog.pg_namespace n
				JOIN pg_catalog.pg_proc p ON p.pronamespace = n.oid
				WHERE n.nspname = 'public'`},
			// DROP TRIGGER requires "name ON table"; internal (constraint)
			// triggers cannot be dropped directly, so they are skipped.
			{"trigger", `SELECT format('%I ON %s', tg.tgname, tg.tgrelid::regclass)
				FROM pg_catalog.pg_trigger tg
				JOIN pg_catalog.pg_class c ON c.oid = tg.tgrelid
				JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
				WHERE n.nspname = 'public' AND NOT tg.tgisinternal`},
		}
		for _, entity := range entities {
			var items []string
			if _, err := pgdb.sql.Query(&items, entity.query); err != nil {
				t.Errorf("get the list of %s returned an error: %v", entity.name, err)
				continue
			}
			for _, item := range items {
				if _, err := pgdb.sql.Exec("drop " + entity.name + " " + item + " cascade"); err != nil {
					t.Errorf("drop %s %s returned an error: %v", entity.name, item, err)
				}
			}
		}
		if err := db.Close(); err != nil {
			t.Errorf("db.Close() returned an error: %v", err)
		}
	}

	buf, err := ioutil.ReadFile("../../createdb.sql")
	if err != nil {
		cleanFn()
		t.Fatalf("error reading sql schema: %v", err)
	}
	if _, err := pgdb.sql.Exec(string(buf)); err != nil {
		cleanFn()
		t.Fatalf("error setting up sql schema: %v", err)
	}
	return db, cleanFn
}

// TestInit checks that the test database can be initialised with the schema
// and torn down again.
func TestInit(t *testing.T) {
	_, dbclose := testDbInit(t)
	defer dbclose()
}