package postgresql

import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"testing"
	"time"

	"github.com/hashicorp/vault/helper/testhelpers/postgresql"
	dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
	dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
)

// getPostgreSQL starts a PostgreSQL test container and returns a plugin
// instance initialized (with connection verification) against it, together
// with the container's cleanup function. Entries in options are merged over
// the default connection details, which contain only "connection_url", so
// callers can override or add plugin configuration such as
// "max_open_connections".
//
// On success the caller owns the returned cleanup function and must invoke it
// when done. If initialization fails the test, the container is cleaned up
// here, because the caller never receives the cleanup function in that case.
func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, func()) {
	cleanup, connURL := postgresql.PrepareTestContainer(t, "latest")

	// t.Fatal (used below, and presumably by AssertInitialize on failure)
	// stops the test via runtime.Goexit, which still runs deferred calls.
	// Use that to tear the container down on any failure path.
	initialized := false
	defer func() {
		if !initialized {
			cleanup()
		}
	}()

	connectionDetails := map[string]interface{}{
		"connection_url": connURL,
	}
	for k, v := range options {
		connectionDetails[k] = v
	}

	req := dbplugin.InitializeRequest{
		Config:           connectionDetails,
		VerifyConnection: true,
	}

	db := new()
	dbtesting.AssertInitialize(t, db, req)

	if !db.Initialized {
		t.Fatal("Database should be initialized")
	}

	// Ownership of cleanup passes to the caller from here on.
	initialized = true
	return db, cleanup
}

// TestPostgreSQL_Initialize checks that the plugin initializes when
// max_open_connections is given as an integer and that it closes cleanly.
func TestPostgreSQL_Initialize(t *testing.T) {
	opts := map[string]interface{}{"max_open_connections": 5}
	db, cleanup := getPostgreSQL(t, opts)
	defer cleanup()

	err := db.Close()
	if err != nil {
		t.Fatalf("err: %s", err)
	}
}

// TestPostgreSQL_InitializeWithStringVals checks that the plugin initializes
// when max_open_connections is given as a string rather than an integer, and
// that it closes cleanly.
func TestPostgreSQL_InitializeWithStringVals(t *testing.T) {
	opts := map[string]interface{}{"max_open_connections": "5"}
	db, cleanup := getPostgreSQL(t, opts)
	defer cleanup()

	err := db.Close()
	if err != nil {
		t.Fatalf("err: %s", err)
	}
}

// TestPostgreSQL_NewUser runs NewUser against a shared container with a
// variety of creation statements. For each case it checks whether the
// returned credentials can log in, once immediately and once more after a
// two-second pause so a role whose VALID UNTIL lands in the past is caught.
func TestPostgreSQL_NewUser(t *testing.T) {
	// newReq builds a request carrying the fields every case shares; only the
	// creation commands differ. Calling it with no arguments yields a request
	// with no statements at all.
	newReq := func(commands ...string) dbplugin.NewUserRequest {
		return dbplugin.NewUserRequest{
			UsernameConfig: dbplugin.UsernameMetadata{
				DisplayName: "test",
				RoleName:    "test",
			},
			Statements: dbplugin.Statements{
				Commands: commands,
			},
			Password:   "somesecurepassword",
			Expiration: time.Now().Add(1 * time.Minute),
		}
	}

	tests := map[string]struct {
		req            dbplugin.NewUserRequest
		expectErr      bool
		credsAssertion credsAssertion
	}{
		"no creation statements": {
			req:            newReq(), // No statements
			expectErr:      true,
			credsAssertion: assertCredsDoNotExist,
		},
		"admin name": {
			req: newReq(`
						CREATE ROLE "{{name}}" WITH
						  LOGIN
						  PASSWORD '{{password}}'
						  VALID UNTIL '{{expiration}}';
						GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
			),
			credsAssertion: assertCredsExist,
		},
		"admin username": {
			req: newReq(`
						CREATE ROLE "{{username}}" WITH
						  LOGIN
						  PASSWORD '{{password}}'
						  VALID UNTIL '{{expiration}}';
						GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`,
			),
			credsAssertion: assertCredsExist,
		},
		"read only name": {
			req: newReq(`
						CREATE ROLE "{{name}}" WITH
						  LOGIN
						  PASSWORD '{{password}}'
						  VALID UNTIL '{{expiration}}';
						GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
						GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`,
			),
			credsAssertion: assertCredsExist,
		},
		"read only username": {
			req: newReq(`
						CREATE ROLE "{{username}}" WITH
						  LOGIN
						  PASSWORD '{{password}}'
						  VALID UNTIL '{{expiration}}';
						GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}";
						GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`,
			),
			credsAssertion: assertCredsExist,
		},
		// https://github.com/hashicorp/vault/issues/6098
		"reproduce GH-6098": {
			// NOTE: "rolname" in the following line is not a typo.
			req: newReq(
				"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$",
			),
			credsAssertion: assertCredsDoNotExist,
		},
		"reproduce issue with template": {
			req: newReq(
				`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`,
			),
			credsAssertion: assertCredsDoNotExist,
		},
		"large block statements": {
			req:            newReq(newUserLargeBlockStatements...),
			credsAssertion: assertCredsExist,
		},
	}

	// Shared test container for speed - there should not be any overlap between the tests
	db, cleanup := getPostgreSQL(t, nil)
	defer cleanup()

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			// Bound each case so a misbehaving statement cannot hang the suite.
			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
			defer cancel()

			resp, err := db.NewUser(ctx, tc.req)
			switch {
			case tc.expectErr && err == nil:
				t.Fatalf("err expected, got nil")
			case !tc.expectErr && err != nil:
				t.Fatalf("no error expected, got: %s", err)
			}

			tc.credsAssertion(t, db.ConnectionURL, resp.Username, tc.req.Password)

			// A second check after a pause catches roles that expire at once.
			time.Sleep(2 * time.Second)
			tc.credsAssertion(t, db.ConnectionURL, resp.Username, tc.req.Password)
		})
	}
}

// TestUpdateUser_Password creates a user with a known password and changes it
// through UpdateUser using several statement sets. It then checks whether the
// new password can log in; for invalid statements it checks that the new
// password cannot. A final subtest verifies that updating a user that does
// not exist returns an error.
func TestUpdateUser_Password(t *testing.T) {
	type testCase struct {
		statements     []string
		expectErr      bool
		credsAssertion credsAssertion
	}

	tests := map[string]testCase{
		"default statements": {
			statements:     nil,
			expectErr:      false,
			credsAssertion: assertCredsExist,
		},
		"explicit default statements": {
			statements:     []string{defaultChangePasswordStatement},
			expectErr:      false,
			credsAssertion: assertCredsExist,
		},
		"name instead of username": {
			statements:     []string{`ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`},
			expectErr:      false,
			credsAssertion: assertCredsExist,
		},
		"bad statements": {
			statements:     []string{`asdofyas8uf77asoiajv`},
			expectErr:      true,
			credsAssertion: assertCredsDoNotExist,
		},
	}

	// Shared test container for speed - there should not be any overlap between the tests
	db, cleanup := getPostgreSQL(t, nil)
	defer cleanup()

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			initialPass := "myreallysecurepassword"
			// Give the role a VALID UNTIL comfortably beyond this subtest's
			// runtime. With a 2-second window, the login check after the
			// update could fail merely because the role had expired on a slow
			// host. One minute matches the other tests in this file.
			createReq := dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{createAdminUser},
				},
				Password:   initialPass,
				Expiration: time.Now().Add(1 * time.Minute),
			}
			createResp := dbtesting.AssertNewUser(t, db, createReq)

			assertCredsExist(t, db.ConnectionURL, createResp.Username, initialPass)

			newPass := "somenewpassword"
			updateReq := dbplugin.UpdateUserRequest{
				Username: createResp.Username,
				Password: &dbplugin.ChangePassword{
					NewPassword: newPass,
					Statements: dbplugin.Statements{
						Commands: test.statements,
					},
				},
			}

			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_, err := db.UpdateUser(ctx, updateReq)
			if test.expectErr && err == nil {
				t.Fatalf("err expected, got nil")
			}
			if !test.expectErr && err != nil {
				t.Fatalf("no error expected, got: %s", err)
			}

			test.credsAssertion(t, db.ConnectionURL, createResp.Username, newPass)
		})
	}

	t.Run("user does not exist", func(t *testing.T) {
		newPass := "somenewpassword"
		updateReq := dbplugin.UpdateUserRequest{
			Username: "missing-user",
			Password: &dbplugin.ChangePassword{
				NewPassword: newPass,
				Statements:  dbplugin.Statements{},
			},
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_, err := db.UpdateUser(ctx, updateReq)
		if err == nil {
			t.Fatalf("err expected, got nil")
		}

		assertCredsDoNotExist(t, db.ConnectionURL, updateReq.Username, newPass)
	})
}

// TestUpdateUser_Expiration creates users with a known VALID UNTIL, changes it
// through UpdateUser, and compares the expiration read back from
// pg_catalog.pg_user against the expected value. For invalid statements the
// update must fail and the original expiration must be left in place. All
// expirations are truncated to whole seconds before being sent and compared.
func TestUpdateUser_Expiration(t *testing.T) {
	const password = "myreallysecurepassword"

	now := time.Now()
	oneMinute := now.Add(1 * time.Minute)
	fiveMinutes := now.Add(5 * time.Minute)

	tests := map[string]struct {
		initialExpiration  time.Time
		newExpiration      time.Time
		expectedExpiration time.Time
		statements         []string
		expectErr          bool
	}{
		"no statements": {
			initialExpiration:  oneMinute,
			newExpiration:      fiveMinutes,
			expectedExpiration: fiveMinutes,
		},
		"default statements with name": {
			initialExpiration:  oneMinute,
			newExpiration:      fiveMinutes,
			expectedExpiration: fiveMinutes,
			statements:         []string{defaultExpirationStatement},
		},
		"default statements with username": {
			initialExpiration:  oneMinute,
			newExpiration:      fiveMinutes,
			expectedExpiration: fiveMinutes,
			statements:         []string{`ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`},
		},
		// A failed update must leave the initial expiration untouched.
		"bad statements": {
			initialExpiration:  oneMinute,
			newExpiration:      fiveMinutes,
			expectedExpiration: oneMinute,
			statements:         []string{"ladshfouay09sgj"},
			expectErr:          true,
		},
	}

	// Shared test container for speed - there should not be any overlap between the tests
	db, cleanup := getPostgreSQL(t, nil)
	defer cleanup()

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			wantInitial := tc.initialExpiration.Truncate(time.Second)
			createResp := dbtesting.AssertNewUser(t, db, dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{createAdminUser},
				},
				Password:   password,
				Expiration: wantInitial,
			})
			username := createResp.Username

			assertCredsExist(t, db.ConnectionURL, username, password)

			got := getExpiration(t, db, username)
			if got.IsZero() {
				t.Fatalf("Initial expiration is zero but should be set")
			}
			if !got.Equal(wantInitial) {
				t.Fatalf("Actual expiration: %s Expected expiration: %s", got, wantInitial)
			}

			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_, err := db.UpdateUser(ctx, dbplugin.UpdateUserRequest{
				Username: username,
				Expiration: &dbplugin.ChangeExpiration{
					NewExpiration: tc.newExpiration.Truncate(time.Second),
					Statements: dbplugin.Statements{
						Commands: tc.statements,
					},
				},
			})
			switch {
			case tc.expectErr && err == nil:
				t.Fatalf("err expected, got nil")
			case !tc.expectErr && err != nil:
				t.Fatalf("no error expected, got: %s", err)
			}

			want := tc.expectedExpiration.Truncate(time.Second)
			if got = getExpiration(t, db, username); !got.Equal(want) {
				t.Fatalf("Actual expiration: %s Expected expiration: %s", got, want)
			}
		})
	}
}

// getExpiration returns the VALID UNTIL timestamp that pg_catalog.pg_user
// records for username. It returns the zero time.Time when no such user
// exists or when the user has no expiration (valuntil is NULL). Any other
// database error fails the test.
func getExpiration(t testing.TB, db *PostgreSQL, username string) time.Time {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, err := db.getConnection(ctx)
	if err != nil {
		t.Fatalf("Failed to get connection to database: %s", err)
	}

	// Bind the username as a query parameter instead of splicing it into the
	// SQL text, so quote characters in the name cannot break the query.
	const query = "select valuntil from pg_catalog.pg_user where usename = $1"

	// QueryRow releases its underlying rows once Scan returns, so nothing is
	// leaked. NullTime separates a NULL valuntil (no expiration) from a real
	// timestamp; scanning NULL into a plain string would fail instead.
	var exp sql.NullTime
	err = conn.QueryRowContext(ctx, query, username).Scan(&exp)
	switch {
	case err == sql.ErrNoRows:
		return time.Time{} // No such user
	case err != nil:
		t.Fatalf("Failed to execute query to get expiration: %s", err)
	}
	if !exp.Valid {
		return time.Time{} // No expiration
	}
	return exp.Time
}

// TestDeleteUser creates a user, removes it with DeleteUser using several
// revocation statement sets, and then checks whether the user can still log
// in. For invalid statements DeleteUser must return an error and the user
// must remain usable.
func TestDeleteUser(t *testing.T) {
	type testCase struct {
		revokeStmts    []string
		expectErr      bool
		credsAssertion credsAssertion
	}

	tests := map[string]testCase{
		"no statements": {
			revokeStmts: nil,
			expectErr:   false,
			// Wait for a short time before failing because postgres takes a moment to finish deleting the user
			credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
		},
		"statements with name": {
			revokeStmts: []string{`
				REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}";
				REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}";
				REVOKE USAGE ON SCHEMA public FROM "{{name}}";
		
				DROP ROLE IF EXISTS "{{name}}";`},
			expectErr: false,
			// Wait for a short time before failing because postgres takes a moment to finish deleting the user
			credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
		},
		"statements with username": {
			revokeStmts: []string{`
				REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{username}}";
				REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{username}}";
				REVOKE USAGE ON SCHEMA public FROM "{{username}}";
		
				DROP ROLE IF EXISTS "{{username}}";`},
			expectErr: false,
			// Wait for a short time before failing because postgres takes a moment to finish deleting the user
			credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
		},
		"bad statements": {
			revokeStmts: []string{`8a9yhfoiasjff`},
			expectErr:   true,
			// Wait for a short time before checking because postgres takes a moment to finish deleting the user
			credsAssertion: assertCredsExistAfter(100 * time.Millisecond),
		},
	}

	// Shared test container for speed - there should not be any overlap between the tests
	db, cleanup := getPostgreSQL(t, nil)
	defer cleanup()

	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			password := "myreallysecurepassword"
			// Keep the role's VALID UNTIL well beyond this subtest's runtime.
			// A 2-second expiration could lapse inside the 2-second
			// waitUntilCredsDoNotExist window, so a DeleteUser that removed
			// nothing would still pass. It could also make the "bad statements"
			// login check fail spuriously on a slow host.
			createReq := dbplugin.NewUserRequest{
				UsernameConfig: dbplugin.UsernameMetadata{
					DisplayName: "test",
					RoleName:    "test",
				},
				Statements: dbplugin.Statements{
					Commands: []string{createAdminUser},
				},
				Password:   password,
				Expiration: time.Now().Add(1 * time.Minute),
			}
			createResp := dbtesting.AssertNewUser(t, db, createReq)

			assertCredsExist(t, db.ConnectionURL, createResp.Username, password)

			deleteReq := dbplugin.DeleteUserRequest{
				Username: createResp.Username,
				Statements: dbplugin.Statements{
					Commands: test.revokeStmts,
				},
			}

			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()

			_, err := db.DeleteUser(ctx, deleteReq)
			if test.expectErr && err == nil {
				t.Fatalf("err expected, got nil")
			}
			if !test.expectErr && err != nil {
				t.Fatalf("no error expected, got: %s", err)
			}

			test.credsAssertion(t, db.ConnectionURL, createResp.Username, password)
		})
	}
}

// credsAssertion checks whether logging in to connURL as username/password
// is in the state a test expects. Implementations in this file fail tb when
// the expectation is not met.
type credsAssertion func(tb testing.TB, connURL, username, password string)

// assertCredsExist fails the test unless a connection to connURL can be
// established and pinged using username and password.
func assertCredsExist(t testing.TB, connURL, username, password string) {
	t.Helper()
	if loginErr := testCredsExist(t, connURL, username, password); loginErr != nil {
		t.Fatalf("user does not exist: %s", loginErr)
	}
}

// assertCredsDoNotExist fails the test if username and password can
// successfully connect to connURL.
func assertCredsDoNotExist(t testing.TB, connURL, username, password string) {
	t.Helper()
	if testCredsExist(t, connURL, username, password) != nil {
		return // login rejected, as expected
	}
	t.Fatalf("user should not exist but does")
}

// waitUntilCredsDoNotExist returns a credsAssertion that retries a login
// every 10ms and succeeds as soon as the login fails. It fails the test if
// the credentials still work once timeout has elapsed.
func waitUntilCredsDoNotExist(timeout time.Duration) credsAssertion {
	return func(t testing.TB, connURL, username, password string) {
		t.Helper()
		deadline, stop := context.WithTimeout(context.Background(), timeout)
		defer stop()

		poll := time.NewTicker(10 * time.Millisecond)
		defer poll.Stop()

		for {
			select {
			case <-poll.C:
				if testCredsExist(t, connURL, username, password) == nil {
					continue // still able to log in; keep polling
				}
				// Happy path: the login was rejected.
				return
			case <-deadline.Done():
				t.Fatalf("Timed out waiting for user %s to be deleted", username)
			}
		}
	}
}

// assertCredsExistAfter returns a credsAssertion that waits for timeout and
// then requires that the credentials can still log in.
func assertCredsExistAfter(timeout time.Duration) credsAssertion {
	return func(t testing.TB, connURL, username, password string) {
		t.Helper()
		<-time.After(timeout)
		assertCredsExist(t, connURL, username, password)
	}
}

// testCredsExist tries to log in to the database as username/password and
// returns the resulting error, or nil if a ping succeeds. It expects connURL
// to contain the literal credentials "postgres:secret"; the first occurrence
// is swapped for the supplied ones.
func testCredsExist(t testing.TB, connURL, username, password string) error {
	t.Helper()
	creds := fmt.Sprintf("%s:%s", username, password)
	userURL := strings.Replace(connURL, "postgres:secret", creds, 1)

	conn, openErr := sql.Open("postgres", userURL)
	if openErr != nil {
		return openErr
	}
	defer conn.Close()

	return conn.Ping()
}

// createAdminUser is the creation statement used by the UpdateUser and
// DeleteUser tests in this file. It creates a login role named from the
// {{name}} template, with the templated password and VALID UNTIL expiration,
// and grants it all privileges on all tables in the public schema.
const createAdminUser = `
CREATE ROLE "{{name}}" WITH
  LOGIN
  PASSWORD '{{password}}'
  VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
`

// newUserLargeBlockStatements is the creation statement list for the
// "large block statements" NewUser case.
//
// The first entry is a multi-line DO block. It creates a shared "foo-role"
// (with a "foo" schema it owns, a foo search_path, and grants on that schema)
// only if the role does not already exist, so it is safe to run repeatedly.
//
// The remaining entries create the per-user login role and grant it
// membership in "foo-role", the same search_path, and CONNECT on the
// "postgres" database.
var newUserLargeBlockStatements = []string{
	`
DO $$
BEGIN
   IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
      CREATE ROLE "foo-role";
      CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
      ALTER ROLE "foo-role" SET search_path = foo;
      GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
      GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
      GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
      GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
      GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
   END IF;
END
$$
`,
	`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`,
	`GRANT "foo-role" TO "{{name}}";`,
	`ALTER ROLE "{{name}}" SET search_path = foo;`,
	`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
}

// TestContainsMultilineStatement checks which statements
// containsMultilineStatement treats as multiline. DO $$ ... $$ blocks, with
// or without template fields, count as multiline. Plain statements separated
// by semicolons, as in the documentation example, do not.
func TestContainsMultilineStatement(t *testing.T) {
	cases := []struct {
		name     string
		input    string
		expected bool
	}{
		{
			name:     "issue 6098 repro",
			input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$`,
			expected: true,
		},
		{
			name:     "multiline with template fields",
			input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname="{{name}}") THEN CREATE ROLE {{name}}; END IF; END $$`,
			expected: true,
		},
		{
			name: "docs example",
			input: `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; \
        GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
			expected: false,
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := containsMultilineStatement(c.input); got != c.expected {
				t.Fatalf("%q should be %t for multiline input", c.input, c.expected)
			}
		})
	}
}

// TestExtractQuotedStrings checks that extractQuotedStrings returns every
// complete double-quoted and single-quoted substring, quotes included.
// The expected order lists double-quoted matches before single-quoted ones.
// An unmatched quote must produce no result.
func TestExtractQuotedStrings(t *testing.T) {
	type testCase struct {
		Input    string
		Expected []string
	}

	testCases := map[string]*testCase{
		"no quotes": {
			Input:    `Five little monkeys jumping on the bed`,
			Expected: []string{},
		},
		"two of both quote types": {
			Input:    `"Five" little 'monkeys' "jumping on" the' 'bed`,
			Expected: []string{`"Five"`, `"jumping on"`, `'monkeys'`, `' '`},
		},
		"one single quote": {
			Input:    `Five little monkeys 'jumping on the bed`,
			Expected: []string{},
		},
		"empty string": {
			Input:    ``,
			Expected: []string{},
		},
		"templated field": {
			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname="{{name}}") THEN CREATE ROLE {{name}}; END IF; END $$`,
			Expected: []string{`"{{name}}"`},
		},
	}

	for tName, tCase := range testCases {
		t.Run(tName, func(t *testing.T) {
			results, err := extractQuotedStrings(tCase.Input)
			if err != nil {
				t.Fatal(err)
			}
			if len(results) != len(tCase.Expected) {
				t.Fatalf("%q isn't equal to %q", results, tCase.Expected)
			}
			// Report the single differing element and its position, not the
			// whole expected slice, so failures point at the mismatch.
			for i, got := range results {
				if want := tCase.Expected[i]; got != want {
					t.Fatalf("index %d: expected %q but received %q", i, want, got)
				}
			}
		})
	}
}
