package commands

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"gopkg.in/yaml.v2"

	"datasmith/utils"
)

// Field describes one column of a table model.
//
// Only JSON tags are declared. When a Field is written to datasmith.yaml the
// YAML keys are whatever yaml.v2 derives by default for untagged fields.
type Field struct {
	Name          string      `json:"name"`
	Type          string      `json:"type"`
	PrimaryKey    bool        `json:"primary_key,omitempty"`
	ForeignKey    string      `json:"foreign_key,omitempty"`
	Unique        bool        `json:"unique,omitempty"`
	NotNull       bool        `json:"not_null,omitempty"`
	AutoIncrement bool        `json:"auto_increment,omitempty"`
	DefaultValue  interface{} `json:"default_value,omitempty"`
}

// TableModel is the ordered list of fields that make up one table.
type TableModel struct {
	Fields []Field `json:"fields"`
}

// DatasmithConfig mirrors the contents of the project's datasmith.yaml file.
type DatasmithConfig struct {
	Name      string                `yaml:"name"`
	Version   string                `yaml:"version"`
	CreatedAt string                `yaml:"created_at"`
	DbType    string                `yaml:"database_type"`
	Tables    map[string]TableModel `yaml:"tables"`
}

// Add registers a new table in the project located in the current directory.
//
// tableName is prompted for when empty. model is an optional JSON document
// describing the table (see TableModel); when it is empty the model is built
// interactively. The model is validated in both cases. On success a
// CREATE TABLE script is written to sql/<slug>.sql and the table is stored in
// datasmith.yaml under its slug. Every failure is reported on stdout and
// aborts the command.
func Add(tableName string, model string) {
	utils.PromptIfEmpty(&tableName, "Enter the name of the table: ")

	projectDir := "." // Assuming current directory is the project directory

	config, err := getDbTypeFromConfig(projectDir)
	if err != nil {
		fmt.Printf("Error loading config: %v\n", err)
		return
	}

	// Tables are stored under their slug, so the duplicate check has to look
	// at the slug as well as the raw name.
	slug := utils.Slugify(tableName)
	_, rawExists := config.Tables[tableName]
	_, slugExists := config.Tables[slug]
	if rawExists || slugExists {
		fmt.Printf("Table '%s' already exists in the project\n", tableName)
		return
	}

	var tableModel TableModel
	if model != "" {
		if err := json.Unmarshal([]byte(model), &tableModel); err != nil {
			fmt.Printf("Error parsing JSON model: %v\n", err)
			return
		}
	} else {
		tableModel = promptForTableModel()
	}

	// Interactive input can be just as inconsistent as hand-written JSON, so
	// both paths go through the same validation.
	if err := validateModel(tableModel); err != nil {
		fmt.Printf("Invalid model: %v\n", err)
		return
	}

	// TODO: Do Stuff:
	/*
		- Generate Test Data for the new table if wanted
		- Add Table to import-sql.sh
		- Add Description to the DBML file
		- Add Description and mermaid and DBML to the README.md file
		- Add Test to gitlab-ci.yml
		- Add to CHANGELOG.md
		- Bump version in datasmith.yaml
	*/

	// Write the SQL file before touching the config: if this step fails the
	// project is left unchanged and the command can simply be re-run.
	if err := createTableFile(projectDir, tableName, config.DbType, tableModel); err != nil {
		return // already reported by createTableFile
	}

	// Update datasmith.yaml with the new table definition.
	if config.Tables == nil {
		config.Tables = make(map[string]TableModel)
	}
	config.Tables[slug] = tableModel
	if err := saveConfig(projectDir, config); err != nil {
		fmt.Printf("Error saving config: %v\n", err)
		return
	}

	fmt.Printf("Added new table '%s' to the project\n", slug)
}

// getDbTypeFromConfig loads datasmith.yaml from projectDir and returns the
// whole decoded configuration (not only the database type). The returned
// Tables map is never nil. A short "Loaded config" line is printed on success.
func getDbTypeFromConfig(projectDir string) (DatasmithConfig, error) {
	configFilePath := filepath.Join(projectDir, "datasmith.yaml")

	file, err := os.Open(configFilePath)
	if err != nil {
		return DatasmithConfig{}, fmt.Errorf("error opening config file: %w", err)
	}
	defer file.Close()

	var config DatasmithConfig
	if err := yaml.NewDecoder(file).Decode(&config); err != nil {
		return DatasmithConfig{}, fmt.Errorf("error decoding config file: %w", err)
	}

	// Initialize Tables map if it's nil so callers can insert right away.
	if config.Tables == nil {
		config.Tables = make(map[string]TableModel)
	}

	fmt.Printf("Loaded config for %s:%s\n", config.Name, config.Version)
	return config, nil
}

// saveConfig writes config to datasmith.yaml inside projectDir, replacing the
// previous contents.
//
// The configuration is serialised before the file is opened, so an encoding
// failure cannot leave a truncated config file behind.
//
// NOTE(review): only the fields of DatasmithConfig are written; any other keys
// present in the original file are presumably lost on save - confirm.
func saveConfig(projectDir string, config DatasmithConfig) error {
	data, err := yaml.Marshal(&config)
	if err != nil {
		return fmt.Errorf("error encoding config file: %w", err)
	}

	configFilePath := filepath.Join(projectDir, "datasmith.yaml")
	file, err := os.Create(configFilePath)
	if err != nil {
		return fmt.Errorf("error creating config file: %w", err)
	}

	if _, err := file.Write(data); err != nil {
		file.Close()
		return fmt.Errorf("error writing config file: %w", err)
	}
	// Close can surface a deferred write error, so it must be checked.
	if err := file.Close(); err != nil {
		return fmt.Errorf("error writing config file: %w", err)
	}
	return nil
}

// createTableFile writes the CREATE TABLE script for tableModel to
// <projectDir>/sql/<slug>.sql.
//
// Failures are reported on stdout here and are additionally returned so the
// caller can stop without printing them a second time. The database type is
// checked before the file is created so that an unsupported type does not
// leave an empty (or truncated) SQL file behind.
func createTableFile(projectDir, tableName, dbType string, tableModel TableModel) error {
	if !isSupportedDbType(dbType) {
		err := fmt.Errorf("unsupported database type: %s", dbType)
		fmt.Printf("Error generating SQL: %v\n", err)
		return err
	}

	slug := utils.Slugify(tableName)
	sqlFilePath := filepath.Join(projectDir, "sql", fmt.Sprintf("%s.sql", slug))

	file, err := os.Create(sqlFilePath)
	if err != nil {
		fmt.Printf("Error creating SQL file: %v\n", err)
		return err
	}

	err = generateSQL(file, tableName, dbType, tableModel)
	// Always close, and keep the first error: a failed Close means the
	// script may not have reached the disk.
	if closeErr := file.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		fmt.Printf("Error generating SQL: %v\n", err)
		return err
	}

	fmt.Printf("Created file: %s\n", sqlFilePath)
	return nil
}

// isSupportedDbType reports whether generateSQL knows how to emit SQL for
// dbType.
func isSupportedDbType(dbType string) bool {
	return dbType == "mysql" || dbType == "postgres"
}

// generateSQL writes the CREATE TABLE statement for tableModel to file using
// the dialect selected by dbType ("mysql" or "postgres"). Any other value
// yields an "unsupported database type" error.
func generateSQL(file *os.File, tableName, dbType string, tableModel TableModel) error {
	switch dbType {
	case "mysql":
		return generateMySQLSQL(file, tableName, tableModel)
	case "postgres":
		return generatePostgreSQLSQL(file, tableName, tableModel)
	default:
		return fmt.Errorf("unsupported database type: %s", dbType)
	}
}

// generateMySQLSQL writes a MySQL CREATE TABLE statement for tableModel to
// file. Auto-increment columns get the AUTO_INCREMENT attribute.
func generateMySQLSQL(file *os.File, tableName string, tableModel TableModel) error {
	_, err := file.WriteString(buildCreateTable(tableName, tableModel, mysqlColumn))
	return err
}

// generatePostgreSQLSQL writes a PostgreSQL CREATE TABLE statement for
// tableModel to file. Auto-increment columns are declared with a serial type
// in place of their declared type.
func generatePostgreSQLSQL(file *os.File, tableName string, tableModel TableModel) error {
	_, err := file.WriteString(buildCreateTable(tableName, tableModel, postgresColumn))
	return err
}

// buildCreateTable renders the CREATE TABLE statement shared by all dialects.
// column renders one column definition; a FOREIGN KEY clause is emitted right
// after every field that declares one.
//
// NOTE(review): tableName, field names and foreign-key targets are inserted
// verbatim (unquoted); names that are not plain identifiers will produce
// invalid SQL.
func buildCreateTable(tableName string, tableModel TableModel, column func(Field) string) string {
	defs := make([]string, 0, len(tableModel.Fields))
	for _, field := range tableModel.Fields {
		defs = append(defs, column(field))
		if field.ForeignKey != "" {
			defs = append(defs, fmt.Sprintf(" FOREIGN KEY (%s) REFERENCES %s", field.Name, field.ForeignKey))
		}
	}
	return fmt.Sprintf("CREATE TABLE %s (\n%s\n);\n", tableName, strings.Join(defs, ",\n"))
}

// mysqlColumn renders one column definition in MySQL syntax.
func mysqlColumn(field Field) string {
	autoIncrement := ""
	if field.AutoIncrement {
		autoIncrement = " AUTO_INCREMENT"
	}
	return formatColumn(field, field.Type, autoIncrement)
}

// postgresColumn renders one column definition in PostgreSQL syntax.
func postgresColumn(field Field) string {
	return formatColumn(field, postgresColumnType(field), "")
}

// postgresColumnType returns the type to emit for field. PostgreSQL has no
// AUTO_INCREMENT attribute; an auto-incrementing column is declared with a
// serial type instead, so the declared integer type is replaced.
func postgresColumnType(field Field) string {
	if !field.AutoIncrement {
		return field.Type
	}
	switch strings.ToUpper(field.Type) {
	case "BIGINT":
		return "BIGSERIAL"
	case "SMALLINT":
		return "SMALLSERIAL"
	default:
		return "SERIAL"
	}
}

// formatColumn renders " <name> <type>" followed by the column attributes in
// the order PRIMARY KEY, UNIQUE, NOT NULL, autoIncrement, DEFAULT.
// autoIncrement is a dialect-specific suffix (including its leading space) or
// the empty string.
func formatColumn(field Field, colType, autoIncrement string) string {
	var b strings.Builder
	fmt.Fprintf(&b, " %s %s", field.Name, colType)
	if field.PrimaryKey {
		b.WriteString(" PRIMARY KEY")
	}
	if field.Unique {
		b.WriteString(" UNIQUE")
	}
	if field.NotNull {
		b.WriteString(" NOT NULL")
	}
	b.WriteString(autoIncrement)
	if field.DefaultValue != nil {
		b.WriteString(" DEFAULT ")
		b.WriteString(quoteSQLString(fmt.Sprintf("%v", field.DefaultValue)))
	}
	return b.String()
}

// quoteSQLString wraps s in single quotes, doubling embedded single quotes so
// the value cannot terminate the literal early.
//
// NOTE(review): backslashes are left untouched; MySQL treats them as escape
// characters unless NO_BACKSLASH_ESCAPES is set.
func quoteSQLString(s string) string {
	return "'" + strings.ReplaceAll(s, "'", "''") + "'"
}

// validateModel checks that model can be turned into a CREATE TABLE
// statement. It returns an error describing the first violated rule:
//   - the model has at least one field
//   - field names are non-empty and unique
//   - field types are listed in FieldTypes
//   - a primary key is not also a foreign key
//   - auto increment is only used on INT primary keys, without a default
func validateModel(model TableModel) error {
	if len(model.Fields) == 0 {
		return fmt.Errorf("model must define at least one field")
	}

	seen := make(map[string]struct{}, len(model.Fields))
	for _, field := range model.Fields {
		// Check if field name is empty
		if field.Name == "" {
			return fmt.Errorf("field name cannot be empty")
		}

		// Check for duplicate field names
		if _, exists := seen[field.Name]; exists {
			return fmt.Errorf("duplicate field name: %s", field.Name)
		}
		seen[field.Name] = struct{}{}

		// Check if field type is valid
		if !isValidType(field.Type) {
			return fmt.Errorf("invalid field type: %s", field.Type)
		}

		// Check if primary key is properly set
		if field.PrimaryKey {
			if field.ForeignKey != "" {
				return fmt.Errorf("field %s cannot be both a primary key and a foreign key", field.Name)
			}
			if field.AutoIncrement && field.Type != "INT" {
				return fmt.Errorf("auto increment can only be used with INT type for field %s", field.Name)
			}
		}

		// Check for incompatible settings
		if field.AutoIncrement && !field.PrimaryKey {
			return fmt.Errorf("auto increment can only be used with a primary key for field %s", field.Name)
		}
		// Both MySQL and PostgreSQL reject an explicit default on an
		// auto-incrementing column.
		if field.AutoIncrement && field.DefaultValue != nil {
			return fmt.Errorf("auto increment field %s cannot have a default value", field.Name)
		}

		// Add any additional validation rules as needed
	}
	return nil
}

// isValidType reports whether fieldType is one of the entries of FieldTypes
// (exact, case-sensitive match).
func isValidType(fieldType string) bool {
	for _, t := range FieldTypes {
		if t == fieldType {
			return true
		}
	}
	return false
}

// promptForTableModel builds a TableModel interactively, asking for one field
// at a time until the user declines to add another.
//
// Questions that can only lead to a model rejected by validateModel are
// skipped: the foreign key is not asked for primary keys, and auto-increment
// is only asked for INT primary keys.
func promptForTableModel() TableModel {
	var fields []Field

	for {
		var field Field
		utils.PromptIfEmpty(&field.Name, "Enter field name: ")
		field.Type = promptForFieldType()

		field.PrimaryKey = utils.PromptForBool("Is this a primary key? (y/n): ") // TODO: use SelectMenu
		if !field.PrimaryKey {
			utils.PromptIfEmpty(&field.ForeignKey, "Enter foreign key (leave empty if not applicable): ")
		}

		field.Unique = utils.PromptForBool("Is this field unique? (y/n): ")    // TODO: use SelectMenu
		field.NotNull = utils.PromptForBool("Is this field not null? (y/n): ") // TODO: use SelectMenu
		if field.PrimaryKey && field.Type == "INT" {
			field.AutoIncrement = utils.PromptForBool("Is this field auto-increment? (y/n): ") // TODO: use SelectMenu
		}

		// An empty answer means "no default". Storing the empty string would
		// make DefaultValue non-nil and emit DEFAULT '' for every column.
		if def := utils.PromptForString("Enter default value (leave empty if not applicable): "); def != "" {
			field.DefaultValue = def
		}

		fields = append(fields, field)

		if !utils.PromptForBool("Do you want to add another field? (y/n): ") { // TODO: use SelectMenu
			break
		}
	}

	return TableModel{Fields: fields}
}

// promptForFieldType lets the user pick one of FieldTypes from a menu and
// returns the selection. If the menu fails the error is printed and the empty
// string is returned.
func promptForFieldType() string {
	options := make([]utils.MenuOption, len(FieldTypes))
	for i, t := range FieldTypes {
		options[i] = utils.MenuOption{Display: t, Value: t}
	}

	selected, err := utils.SelectMenu(options, "Select field type:")
	if err != nil {
		fmt.Printf("Error selecting field type: %v\n", err)
		return ""
	}
	return selected
}

// AddHelp returns the help information for the add command.
func AddHelp() string {
	return "add <tablename> [--model <json_model>]: Add a table to the project."
}