diff --git a/commands/add.go b/commands/add.go index 59eea3168497698cb9ccd50d2d45b03d9b6271b4..c8b5c29b7d8ff086a5159a90dcfcafd5cacb427c 100644 --- a/commands/add.go +++ b/commands/add.go @@ -2,24 +2,69 @@ package commands import ( "fmt" + "encoding/json" + "os" + "strings" + "path/filepath" "datasmith/utils" + "gopkg.in/yaml.v2" ) -func Add(tablename string) { - utils.PromptIfEmpty(&tablename, "Enter the name of the table: ") - addTable(tablename) +type Field struct { + Name string `json:"name"` + Type string `json:"type"` + PrimaryKey bool `json:"primary_key,omitempty"` + ForeignKey string `json:"foreign_key,omitempty"` + Unique bool `json:"unique,omitempty"` + NotNull bool `json:"not_null,omitempty"` + AutoIncrement bool `json:"auto_increment,omitempty"` + DefaultValue interface{} `json:"default_value,omitempty"` } -func addTable(name string) { - slug := utils.Slugify(name) +type TableModel struct { + Fields []Field `json:"fields"` +} + +type DatasmithConfig struct { + Name string `yaml:"name"` + Version string `yaml:"version"` + CreatedAt string `yaml:"created_at"` + DbType string `yaml:"database_type"` + Tables map[string]TableModel `yaml:"tables"` +} + +func Add(tableName string, model string) { + utils.PromptIfEmpty(&tableName, "Enter the name of the table: ") + + projectDir := "." // Assuming current directory is the project directory + config, err := getDbTypeFromConfig(projectDir) + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + return + } - // TODO: Do Stuff: + if _, exists := config.Tables[tableName]; exists { + fmt.Printf("Table '%s' already exists in the project\n", tableName) + return + } + + var tableModel TableModel + if model != "" { + if err := json.Unmarshal([]byte(model), &tableModel); err != nil { + fmt.Printf("Error parsing JSON model: %v\n", err) + return + } + if err := validateModel(tableModel); err != nil { + fmt.Printf("Invalid model: %v\n", err) + return + } + } else { + tableModel = promptForTableModel() + } + + +// TODO: Do Stuff: /* - - Parameter --model (yamL) as file or string - - if nothing, ask for fields (name, type, primary key, foreign key, unique, not null, auto increment, default value) - - Check if the table already exists in the datasmith.yaml file - - Add a new table and columns to the datasmith.yaml file - - Create a new table file in the sql directory (by database type) sql/slug.sql - Generate Test Data for the new table if wanted - Add Table to import-sql.sh - Add Description to the DBML file @@ -30,10 +75,265 @@ func addTable(name string) { */ + slug := utils.Slugify(tableName) + + // Update datasmith.yaml config and save the new table definition + // Initialize Tables map if it's nil + if config.Tables == nil { + config.Tables = make(map[string]TableModel) + } + config.Tables[slug] = tableModel + if err := saveConfig(projectDir, config); err != nil { + fmt.Printf("Error saving config: %v\n", err) + } + + // Create the table file under sql directory + createTableFile(projectDir, tableName, config.DbType, tableModel) + + fmt.Printf("Added new table '%s' to the project\n", slug) } +func getDbTypeFromConfig(projectDir string) (DatasmithConfig, error) { + configFilePath := filepath.Join(projectDir, "datasmith.yaml") + file, err := os.Open(configFilePath) + if err != nil { + return DatasmithConfig{}, fmt.Errorf("error opening config file: %v", err) + } + defer file.Close() + + var config DatasmithConfig + decoder := yaml.NewDecoder(file) + if err := decoder.Decode(&config); err != nil { + return DatasmithConfig{}, fmt.Errorf("error decoding config file: %v", err) + } + + // Initialize Tables map if it's nil + if config.Tables == nil { + config.Tables = make(map[string]TableModel) + } + + fmt.Printf("Loaded config for %s:%s\n", config.Name, config.Version) + + return config, nil +} + +func saveConfig(projectDir string, config DatasmithConfig) error { + configFilePath := filepath.Join(projectDir, "datasmith.yaml") + file, err := os.Create(configFilePath) + if err != nil { + return fmt.Errorf("error creating config file: %v", err) + } + defer file.Close() + + encoder := yaml.NewEncoder(file) + if err := encoder.Encode(&config); err != nil { + return fmt.Errorf("error encoding config file: %v", err) + } + + return nil +} + +func createTableFile(projectDir, tableName, dbType string, tableModel TableModel) { + slug := utils.Slugify(tableName) + sqlFilePath := filepath.Join(projectDir, "sql", fmt.Sprintf("%s.sql", slug)) + + file, err := os.Create(sqlFilePath) + if err != nil { + fmt.Printf("Error creating SQL file: %v\n", err) + return + } + defer file.Close() + + if err := generateSQL(file, tableName, dbType, tableModel); err != nil { + fmt.Printf("Error generating SQL: %v\n", err) + return + } + + fmt.Printf("Created file: %s\n", sqlFilePath) +} + + +func generateSQL(file *os.File, tableName, dbType string, tableModel TableModel) error { + switch dbType { + case "mysql": + return generateMySQLSQL(file, tableName, tableModel) + case "postgres": + return generatePostgreSQLSQL(file, tableName, tableModel) + default: + return fmt.Errorf("unsupported database type: %s", dbType) + } +} + +func generateMySQLSQL(file *os.File, tableName string, tableModel TableModel) error { + var sql strings.Builder + + sql.WriteString(fmt.Sprintf("CREATE TABLE %s (\n", tableName)) + + for i, field := range tableModel.Fields { + sql.WriteString(fmt.Sprintf(" %s %s", field.Name, field.Type)) + + if field.PrimaryKey { + sql.WriteString(" PRIMARY KEY") + } + if field.Unique { + sql.WriteString(" UNIQUE") + } + if field.NotNull { + sql.WriteString(" NOT NULL") + } + if field.AutoIncrement { + sql.WriteString(" AUTO_INCREMENT") + } + if field.DefaultValue != nil { + sql.WriteString(fmt.Sprintf(" DEFAULT '%v'", field.DefaultValue)) + } + if field.ForeignKey != "" { + sql.WriteString(fmt.Sprintf(",\n FOREIGN KEY (%s) REFERENCES %s", field.Name, field.ForeignKey)) + } + + if i < len(tableModel.Fields)-1 { + sql.WriteString(",\n") + } + } + + sql.WriteString("\n);\n") + + _, err := file.WriteString(sql.String()) + return err +} + +func generatePostgreSQLSQL(file *os.File, tableName string, tableModel TableModel) error { + var sql strings.Builder + + sql.WriteString(fmt.Sprintf("CREATE TABLE %s (\n", tableName)) + + for i, field := range tableModel.Fields { + sql.WriteString(fmt.Sprintf(" %s %s", field.Name, field.Type)) + + if field.PrimaryKey { + sql.WriteString(" PRIMARY KEY") + } + if field.Unique { + sql.WriteString(" UNIQUE") + } + if field.NotNull { + sql.WriteString(" NOT NULL") + } + if field.AutoIncrement { + sql.WriteString(" SERIAL") + } + if field.DefaultValue != nil { + sql.WriteString(fmt.Sprintf(" DEFAULT '%v'", field.DefaultValue)) + } + if field.ForeignKey != "" { + sql.WriteString(fmt.Sprintf(",\n FOREIGN KEY (%s) REFERENCES %s", field.Name, field.ForeignKey)) + } + + if i < len(tableModel.Fields)-1 { + sql.WriteString(",\n") + } + } + + sql.WriteString("\n);\n") + + _, err := file.WriteString(sql.String()) + return err +} + +func validateModel(model TableModel) error { + fieldNames := make(map[string]bool) + + for _, field := range model.Fields { + // Check if field name is empty + if field.Name == "" { + return fmt.Errorf("field name cannot be empty") + } + + // Check for duplicate field names + if _, exists := fieldNames[field.Name]; exists { + return fmt.Errorf("duplicate field name: %s", field.Name) + } + fieldNames[field.Name] = true + + // Check if field type is valid + if !isValidType(field.Type) { + return fmt.Errorf("invalid field type: %s", field.Type) + } + + // Check if primary key is properly set + if field.PrimaryKey { + if field.ForeignKey != "" { + return fmt.Errorf("field %s cannot be both a primary key and a foreign key", field.Name) + } + if field.AutoIncrement && field.Type != "INT" { + return fmt.Errorf("auto increment can only be used with INT type for field %s", field.Name) + } + } + + // Check for incompatible settings + if field.AutoIncrement && !field.PrimaryKey { + return fmt.Errorf("auto increment can only be used with a primary key for field %s", field.Name) + } + + // Add any additional validation rules as needed + } + + return nil +} + +func isValidType(fieldType string) bool { + for _, t := range FieldTypes { + if t == fieldType { + return true + } + } + return false +} + +func promptForTableModel() TableModel { + var fields []Field + + for { + var field Field + utils.PromptIfEmpty(&field.Name, "Enter field name: ") + field.Type = promptForFieldType() + + field.PrimaryKey = utils.PromptForBool("Is this a primary key? (y/n): ") + if !field.PrimaryKey { + utils.PromptIfEmpty(&field.ForeignKey, "Enter foreign key (leave empty if not applicable): ") + } + field.Unique = utils.PromptForBool("Is this field unique? (y/n): ") + field.NotNull = utils.PromptForBool("Is this field not null? (y/n): ") + field.AutoIncrement = utils.PromptForBool("Is this field auto-increment? (y/n): ") + field.DefaultValue = utils.PromptForString("Enter default value (leave empty if not applicable): ") + + fields = append(fields, field) + + if !utils.PromptForBool("Do you want to add another field? (y/n): ") { + break + } + } + + return TableModel{Fields: fields} +} + +func promptForFieldType() string { + options := make([]utils.MenuOption, len(FieldTypes)) + for i, t := range FieldTypes { + options[i] = utils.MenuOption{Display: t, Value: t} + } + + selected, err := utils.SelectMenu(options, "Select field type:") + if err != nil { + fmt.Printf("Error selecting field type: %v\n", err) + return "" + } + + return selected +} + // Help returns the help information for the add command func AddHelp() string { - return "add [tablename]: Add a table to the project." + return "add <tablename> [--model <json_model>]: Add a table to the project." } \ No newline at end of file diff --git a/commands/types.go b/commands/types.go new file mode 100644 index 0000000000000000000000000000000000000000..fa0314cb04880a9a115fed4365ca3c92ca017f89 --- /dev/null +++ b/commands/types.go @@ -0,0 +1,14 @@ +package commands + +var FieldTypes = []string{ + "INT", + "VARCHAR(255)", + "TEXT", + "DATE", + "DATETIME", + "BOOLEAN", + "FLOAT", + "DOUBLE", + "DECIMAL(10,2)", + "UUID", +} diff --git a/go.mod b/go.mod index a261b1e60e445e99ee91bb386c75e3cd8608a0e5..cbc7f691ec8f7e4668ecad76f3086ad15387f110 100644 --- a/go.mod +++ b/go.mod @@ -6,4 +6,5 @@ require ( github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/manifoldco/promptui v0.9.0 // indirect golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 // indirect + gopkg.in/yaml.v2 v2.4.0 ) diff --git a/go.sum b/go.sum index 66c56875ee8841f1dc7e16f232012c331f2a856e..3dd0071b7c7823495bfec4c68528f8444265a69d 100644 --- a/go.sum +++ b/go.sum @@ -7,3 +7,6 @@ github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GW golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5 h1:y/woIyUBFbpQGKS0u1aHF/40WUDnek3fPOyD08H5Vng= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/main.go b/main.go index 2adfc2f83cb7bfc853b62c0750588f8ed2a59df1..eeb075721ed681adbbf0ab3c3bcb93ba756936a9 100644 --- a/main.go +++ b/main.go @@ -32,14 +32,18 @@ func main() { } } commands.InitProject(name, dbType) - fmt.Println("Done. Now run 'ds add table' to add a table to the database.") + fmt.Println("Done. Now run 'ds add' to add a table to the database.") case "add": - if len(os.Args) < 3 { - fmt.Println("Usage: ds add <tablename>") - return + var tableName, model string + args := os.Args[2:] + for i, arg := range args { + if arg == "--model" && i+1 < len(args) { + model = args[i+1] + } else if !strings.HasPrefix(arg, "--") { + tableName = arg + } } - name := os.Args[2] - commands.Add(name) + commands.Add(tableName, model) case "version": commands.Version(version) case "help": diff --git a/templates/datasmith.yaml.tmpl b/templates/datasmith.yaml.tmpl index 904961c16163c782569cab10c113a74a4556cc63..a771f207f12cfa317c94ea11fbfe8757b92cc5d3 100644 --- a/templates/datasmith.yaml.tmpl +++ b/templates/datasmith.yaml.tmpl @@ -1,5 +1,3 @@ -# datasmith.yaml - name: {{ .Name }} version: 0.0.1 created_at: {{ .CreatedAt }} diff --git a/utils/utils.go b/utils/utils.go index fbb43f3b2a1e712c4e4c2316b2e070e32028df94..56279d217f2c99d93da492064b0b886ef49f391c 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -43,6 +43,33 @@ func PromptIfEmpty(value *string, prompt string) { } } + +// PromptForBool prompts the user for a yes/no response and returns a boolean +func PromptForBool(prompt string) bool { + for { + fmt.Print(prompt) + reader := bufio.NewReader(os.Stdin) + input, _ := reader.ReadString('\n') + input = strings.ToLower(strings.TrimSpace(input)) + + if input == "y" || input == "yes" { + return true + } else if input == "n" || input == "no" { + return false + } else { + fmt.Println("Invalid input, please enter 'y' or 'n'") + } + } +} + +// PromptForString prompts the user for a string response and returns it +func PromptForString(prompt string) string { + reader := bufio.NewReader(os.Stdin) + fmt.Print(prompt) + input, _ := reader.ReadString('\n') + return strings.TrimSpace(input) +} + // SelectMenu displays a menu with options and returns the selected value func SelectMenu(options []MenuOption, prompt string) (string, error) { // Handle interrupt signal (Ctrl+C) @@ -75,6 +102,5 @@ func SelectMenu(options []MenuOption, prompt string) (string, error) { return option.Value, nil } } - return "", nil } \ No newline at end of file