diff --git a/config.go b/config.go new file mode 100644 index 0000000..8a2ae7e --- /dev/null +++ b/config.go @@ -0,0 +1,219 @@ +package main + +import ( + "errors" + "fmt" + "os" + "path" + "path/filepath" + "regexp" + "runtime" + "strings" + + lua "github.com/yuin/gopher-lua" +) + +const ( + configFileName = "config.lua" + filesDirName = "files" +) + +type packageConfig struct { + targets map[string]string + disabled map[string]bool + ignore []string + compiledIgnores []*regexp.Regexp +} + +var errTargetDisabled = errors.New("target disabled for this platform") + +func loadConfig(path string) (*packageConfig, error) { + L := lua.NewState() + defer L.Close() + L.OpenLibs() + + if err := L.DoFile(path); err != nil { + return nil, err + } + if L.GetTop() == 0 { + return nil, errors.New("config.lua must return a table") + } + + value := L.Get(-1) + tbl, ok := value.(*lua.LTable) + if !ok { + return nil, errors.New("config.lua must return a table") + } + + targetVal := tbl.RawGetString("target") + targetTbl, ok := targetVal.(*lua.LTable) + if !ok { + return nil, errors.New("config.target must be a table") + } + + targets := make(map[string]string) + disabled := make(map[string]bool) + targetTbl.ForEach(func(k, v lua.LValue) { + ks, ok := k.(lua.LString) + if !ok { + return + } + + if v == lua.LNil || v == lua.LFalse { + disabled[string(ks)] = true + return + } + + vs, ok := v.(lua.LString) + if !ok { + return + } + + targets[string(ks)] = expandHome(string(vs)) + }) + + if len(targets) == 0 && len(disabled) == 0 { + return nil, errors.New("config.target is empty") + } + + ignore, err := parseIgnore(tbl) + if err != nil { + return nil, err + } + compiledIgnores, err := compileIgnorePatterns(ignore) + if err != nil { + return nil, err + } + + return &packageConfig{targets: targets, disabled: disabled, ignore: ignore, compiledIgnores: compiledIgnores}, nil +} + +func selectTarget(cfg *packageConfig) (string, error) { + osKey := runtime.GOOS + if osKey == "darwin" { + osKey = "macos" + } + if cfg.disabled[osKey] { + return "", errTargetDisabled + } + if target, ok := cfg.targets[osKey]; ok { + return expandHome(target), nil + } + if target, ok := cfg.targets["default"]; ok { + return expandHome(target), nil + } + return "", fmt.Errorf("missing target for %s and default", osKey) +} + +func parseIgnore(cfgTbl *lua.LTable) ([]string, error) { + ignoreVal := cfgTbl.RawGetString("ignore") + if ignoreVal == lua.LNil { + return nil, nil + } + + ignoreTbl, ok := ignoreVal.(*lua.LTable) + if !ok { + return nil, errors.New("config.ignore must be an array of strings") + } + + ignore := make([]string, 0, ignoreTbl.Len()) + var parseErr error + ignoreTbl.ForEach(func(k, v lua.LValue) { + if parseErr != nil { + return + } + + if _, ok := k.(lua.LNumber); !ok { + parseErr = errors.New("config.ignore must be an array of strings") + return + } + + s, ok := v.(lua.LString) + if !ok { + parseErr = errors.New("config.ignore must contain only strings") + return + } + + pattern := strings.TrimSpace(string(s)) + if pattern == "" { + parseErr = errors.New("config.ignore cannot contain empty patterns") + return + } + + ignore = append(ignore, pattern) + }) + if parseErr != nil { + return nil, parseErr + } + return ignore, nil +} + +func compileIgnorePatterns(patterns []string) ([]*regexp.Regexp, error) { + compiled := make([]*regexp.Regexp, 0, len(patterns)) + for _, pattern := range patterns { + re, err := globToRegexp(pattern) + if err != nil { + return nil, fmt.Errorf("invalid ignore pattern %q: %w", pattern, err) + } + compiled = append(compiled, re) + } + return compiled, nil +} + +func globToRegexp(pattern string) (*regexp.Regexp, error) { + pattern = strings.ReplaceAll(pattern, "\\", "/") + pattern = strings.TrimPrefix(pattern, "./") + if strings.HasPrefix(pattern, "/") { + pattern = strings.TrimPrefix(pattern, "/") + } + + var b strings.Builder + b.WriteString("^") + for i := 0; i < len(pattern); { + if i+1 < len(pattern) && pattern[i] == '*' && pattern[i+1] == '*' { + b.WriteString(".*") + i += 2 + continue + } + + ch := pattern[i] + switch ch { + case '*': + b.WriteString("[^/]*") + case '?': + b.WriteString("[^/]") + default: + b.WriteString(regexp.QuoteMeta(string(ch))) + } + i++ + } + b.WriteString("$") + + return regexp.Compile(b.String()) +} + +func shouldIgnorePath(rel string, cfg *packageConfig) bool { + if cfg == nil || len(cfg.compiledIgnores) == 0 { + return false + } + + normalized := path.Clean(filepath.ToSlash(rel)) + for _, re := range cfg.compiledIgnores { + if re.MatchString(normalized) { + return true + } + } + + return false +} + +func writeConfig(path, targetRoot string) error { + osKey := "linux" + if runtime.GOOS == "darwin" { + osKey = "macos" + } + + prettyTarget := compressHome(targetRoot) + content := fmt.Sprintf("---@class SigilConfig\n---@field target table\n---@field ignore? string[]\n\n---@type SigilConfig\nlocal config = {\n\ttarget = {\n\t\t%s = %q,\n\t\tdefault = %q,\n\t},\n\tignore = {\n\t\t-- \"**/.DS_Store\",\n\t\t-- \"**/*.tmp\",\n\t\t-- \"cache/**\",\n\t},\n}\n\nreturn config\n", osKey, prettyTarget, prettyTarget) + return os.WriteFile(path, []byte(content), 0o644) +} diff --git a/main.go b/main.go index 7e91494..ef40f4c 100644 --- a/main.go +++ b/main.go @@ -1,34 +1,14 @@ package main import ( - "bufio" "errors" "fmt" - "io" - "io/fs" "os" - "path" "path/filepath" - "regexp" - "runtime" "strings" - - lua "github.com/yuin/gopher-lua" ) -const ( - configFileName = "config.lua" - filesDirName = "files" -) - -type packageConfig struct { - targets map[string]string - disabled map[string]bool - ignore []string - compiledIgnores []*regexp.Regexp -} - -var errTargetDisabled = errors.New("target disabled for this platform") +const version = "0.1.0" func main() { if len(os.Args) < 2 { @@ -38,6 +18,11 @@ func main() { args := os.Args[1:] + if args[0] == "-v" || args[0] == "--version" { + fmt.Println("sigil", version) + os.Exit(0) + } + var err error switch args[0] { case "apply": @@ -73,142 +58,6 @@ func usage() { fmt.Println(" sigil status") } -type packageFlags struct { - dryRun bool -} - -func parsePackageFlags(args []string) (packageFlags, string, error) { - flags := packageFlags{} - var pkg string - for _, arg := range args { - if arg == "--dry-run" { - flags.dryRun = true - continue - } - if strings.HasPrefix(arg, "-") { - return flags, "", fmt.Errorf("unknown flag %q", arg) - } - if pkg != "" { - return flags, "", errors.New("too many arguments") - } - pkg = arg - } - if pkg == "" { - return flags, "", errors.New("missing package") - } - return flags, pkg, nil -} - -func splitPackageSpec(spec string) (string, string, error) { - if spec == "" { - return "", "", errors.New("missing package") - } - - parts := strings.SplitN(spec, ":", 2) - pkg := parts[0] - rel := "" - if len(parts) == 2 { - rel = parts[1] - } - - pkg = strings.Trim(pkg, "/") - rel = strings.TrimPrefix(rel, "/") - - if pkg == "" { - return "", "", errors.New("invalid package") - } - - return pkg, rel, nil -} - -func resolvePackageSpec(spec string) (string, string, error) { - repo, err := repoPath() - if err != nil { - return "", "", err - } - - repoAbs, err := filepath.Abs(repo) - if err != nil { - return "", "", err - } - - spec = expandHome(spec) - if filepath.IsAbs(spec) { - return resolvePathSpec(spec, repoAbs) - } - - if strings.Contains(spec, string(os.PathSeparator)) { - return resolvePathSpec(spec, repoAbs) - } - - clean := filepath.Clean(spec) - if strings.HasPrefix(clean, ".") || strings.HasPrefix(clean, string(os.PathSeparator)) { - return resolvePathSpec(clean, repoAbs) - } - - return splitPackageSpec(spec) -} - -func resolvePathSpec(pathSpec, repoAbs string) (string, string, error) { - absPath, err := filepath.Abs(pathSpec) - if err != nil { - return "", "", err - } - - if strings.HasPrefix(absPath, repoAbs+string(os.PathSeparator)) || absPath == repoAbs { - rel, err := filepath.Rel(repoAbs, absPath) - if err != nil { - return "", "", err - } - parts := strings.Split(rel, string(os.PathSeparator)) - if len(parts) >= 1 { - pkg := parts[0] - if len(parts) >= 2 && parts[1] == filesDirName { - relPath := filepath.Join(parts[2:]...) - return pkg, relPath, nil - } - return pkg, filepath.Join(parts[1:]...), nil - } - } - - entries, err := os.ReadDir(repoAbs) - if err != nil { - return "", "", err - } - - for _, entry := range entries { - if !entry.IsDir() || strings.HasPrefix(entry.Name(), ".") { - continue - } - pkgDir := filepath.Join(repoAbs, entry.Name()) - configPath := filepath.Join(pkgDir, configFileName) - cfg, err := loadConfig(configPath) - if err != nil { - continue - } - targetRoot, err := selectTarget(cfg) - if err != nil { - if errors.Is(err, errTargetDisabled) { - continue - } - continue - } - absTarget, err := filepath.Abs(expandHome(targetRoot)) - if err != nil { - continue - } - if strings.HasPrefix(absPath, absTarget+string(os.PathSeparator)) || absPath == absTarget { - rel, err := filepath.Rel(absTarget, absPath) - if err != nil { - return "", "", err - } - return entry.Name(), rel, nil - } - } - - return "", "", fmt.Errorf("could not resolve %s to a package", pathSpec) -} - func applyCmd(args []string) error { prune := false for _, arg := range args { @@ -332,7 +181,6 @@ func addCmd(args []string) error { defaultPkg = filepath.Base(absPath) } - reader := bufio.NewReader(os.Stdin) var pkgName string matchedPkg, err := findPackageByTarget(repo, defaultTarget) @@ -350,7 +198,7 @@ func addCmd(args []string) error { } if pkgName == "" { - pkgName = promptWithDefault(reader, "Package name", defaultPkg) + pkgName = promptWithDefault("Package name", defaultPkg) if pkgName == "" { return errors.New("package name cannot be empty") } @@ -358,7 +206,7 @@ func addCmd(args []string) error { targetRootInput := defaultTarget if pkgName != matchedPkg || matchedPkg == "" { - targetRootInput = promptWithDefault(reader, "Target path", defaultTarget) + targetRootInput = promptWithDefault("Target path", defaultTarget) } targetRoot, err := filepath.Abs(expandHome(targetRootInput)) if err != nil { @@ -544,144 +392,6 @@ func removeCmd(args []string) error { return nil } -func applyPackage(filesDir, targetRoot string, cfg *packageConfig) error { - if err := ensureDir(targetRoot); err != nil { - return err - } - - return filepath.WalkDir(filesDir, func(path string, entry fs.DirEntry, err error) error { - if err != nil { - return err - } - if path == filesDir { - return nil - } - - rel, err := filepath.Rel(filesDir, path) - if err != nil { - return err - } - - if shouldIgnorePath(rel, cfg) { - if entry.IsDir() { - return filepath.SkipDir - } - return nil - } - - targetPath := filepath.Join(targetRoot, rel) - - if entry.IsDir() { - return ensureDir(targetPath) - } - - srcAbs, err := filepath.Abs(path) - if err != nil { - return err - } - return linkFile(srcAbs, targetPath) - }) -} - -func linkFile(src, dst string) error { - if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { - return err - } - - if info, err := os.Lstat(dst); err == nil { - if info.Mode()&os.ModeSymlink != 0 { - current, err := os.Readlink(dst) - if err != nil { - return err - } - if current == src { - return nil - } - return fmt.Errorf("conflict at %s (points to %s)", dst, current) - } - return fmt.Errorf("conflict at %s (exists and is not a symlink)", dst) - } else if !errors.Is(err, os.ErrNotExist) { - return err - } - - return os.Symlink(src, dst) -} - -func ensureDir(path string) error { - info, err := os.Lstat(path) - if err == nil { - if info.IsDir() { - return nil - } - return fmt.Errorf("%s exists and is not a directory", path) - } - if !errors.Is(err, os.ErrNotExist) { - return err - } - return os.MkdirAll(path, 0o755) -} - -func loadConfig(path string) (*packageConfig, error) { - L := lua.NewState() - defer L.Close() - L.OpenLibs() - - if err := L.DoFile(path); err != nil { - return nil, err - } - if L.GetTop() == 0 { - return nil, errors.New("config.lua must return a table") - } - - value := L.Get(-1) - tbl, ok := value.(*lua.LTable) - if !ok { - return nil, errors.New("config.lua must return a table") - } - - targetVal := tbl.RawGetString("target") - targetTbl, ok := targetVal.(*lua.LTable) - if !ok { - return nil, errors.New("config.target must be a table") - } - - targets := make(map[string]string) - disabled := make(map[string]bool) - targetTbl.ForEach(func(k, v lua.LValue) { - ks, ok := k.(lua.LString) - if !ok { - return - } - - if v == lua.LNil || v == lua.LFalse { - disabled[string(ks)] = true - return - } - - vs, ok := v.(lua.LString) - if !ok { - return - } - - targets[string(ks)] = expandHome(string(vs)) - }) - - if len(targets) == 0 && len(disabled) == 0 { - return nil, errors.New("config.target is empty") - } - - ignore, err := parseIgnore(tbl) - if err != nil { - return nil, err - } - compiledIgnores, err := compileIgnorePatterns(ignore) - if err != nil { - return nil, err - } - - return &packageConfig{targets: targets, disabled: disabled, ignore: ignore, compiledIgnores: compiledIgnores}, nil -} - func statusCmd() error { repo, err := repoPath() if err != nil { @@ -731,605 +441,3 @@ func statusCmd() error { return nil } - -func selectTarget(cfg *packageConfig) (string, error) { - osKey := runtime.GOOS - if osKey == "darwin" { - osKey = "macos" - } - if cfg.disabled[osKey] { - return "", errTargetDisabled - } - if target, ok := cfg.targets[osKey]; ok { - return expandHome(target), nil - } - if target, ok := cfg.targets["default"]; ok { - return expandHome(target), nil - } - return "", fmt.Errorf("missing target for %s and default", osKey) -} - -func parseIgnore(cfgTbl *lua.LTable) ([]string, error) { - ignoreVal := cfgTbl.RawGetString("ignore") - if ignoreVal == lua.LNil { - return nil, nil - } - - ignoreTbl, ok := ignoreVal.(*lua.LTable) - if !ok { - return nil, errors.New("config.ignore must be an array of strings") - } - - ignore := make([]string, 0, ignoreTbl.Len()) - var parseErr error - ignoreTbl.ForEach(func(k, v lua.LValue) { - if parseErr != nil { - return - } - - if _, ok := k.(lua.LNumber); !ok { - parseErr = errors.New("config.ignore must be an array of strings") - return - } - - s, ok := v.(lua.LString) - if !ok { - parseErr = errors.New("config.ignore must contain only strings") - return - } - - pattern := strings.TrimSpace(string(s)) - if pattern == "" { - parseErr = errors.New("config.ignore cannot contain empty patterns") - return - } - - ignore = append(ignore, pattern) - }) - if parseErr != nil { - return nil, parseErr - } - return ignore, nil -} - -func compileIgnorePatterns(patterns []string) ([]*regexp.Regexp, error) { - compiled := make([]*regexp.Regexp, 0, len(patterns)) - for _, pattern := range patterns { - re, err := globToRegexp(pattern) - if err != nil { - return nil, fmt.Errorf("invalid ignore pattern %q: %w", pattern, err) - } - compiled = append(compiled, re) - } - return compiled, nil -} - -func globToRegexp(pattern string) (*regexp.Regexp, error) { - pattern = strings.ReplaceAll(pattern, "\\", "/") - pattern = strings.TrimPrefix(pattern, "./") - if strings.HasPrefix(pattern, "/") { - pattern = strings.TrimPrefix(pattern, "/") - } - - var b strings.Builder - b.WriteString("^") - for i := 0; i < len(pattern); { - if i+1 < len(pattern) && pattern[i] == '*' && pattern[i+1] == '*' { - b.WriteString(".*") - i += 2 - continue - } - - ch := pattern[i] - switch ch { - case '*': - b.WriteString("[^/]*") - case '?': - b.WriteString("[^/]") - default: - b.WriteString(regexp.QuoteMeta(string(ch))) - } - i++ - } - b.WriteString("$") - - return regexp.Compile(b.String()) -} - -func shouldIgnorePath(rel string, cfg *packageConfig) bool { - if cfg == nil || len(cfg.compiledIgnores) == 0 { - return false - } - - normalized := path.Clean(filepath.ToSlash(rel)) - for _, re := range cfg.compiledIgnores { - if re.MatchString(normalized) { - return true - } - } - - return false -} - -func repoPath() (string, error) { - if override := os.Getenv("SIGIL_REPO"); override != "" { - return filepath.Abs(expandHome(override)) - } - return filepath.Abs(expandHome("~/.dotfiles")) -} - -func expandHome(path string) string { - if path == "~" { - home, err := os.UserHomeDir() - if err != nil { - return path - } - return home - } - if strings.HasPrefix(path, "~/") { - home, err := os.UserHomeDir() - if err != nil { - return path - } - return filepath.Join(home, path[2:]) - } - return path -} - -func compressHome(path string) string { - home, err := os.UserHomeDir() - if err != nil { - return path - } - clean := filepath.Clean(path) - homeClean := filepath.Clean(home) - if clean == homeClean { - return "~" - } - if strings.HasPrefix(clean, homeClean+string(os.PathSeparator)) { - rel := strings.TrimPrefix(clean, homeClean+string(os.PathSeparator)) - return filepath.Join("~", rel) - } - return path -} - -func writeConfig(path, targetRoot string) error { - osKey := "linux" - if runtime.GOOS == "darwin" { - osKey = "macos" - } - - prettyTarget := compressHome(targetRoot) - content := fmt.Sprintf("---@class SigilConfig\n---@field target table\n---@field ignore? string[]\n\n---@type SigilConfig\nlocal config = {\n\ttarget = {\n\t\t%s = %q,\n\t\tdefault = %q,\n\t},\n\tignore = {\n\t\t-- \"**/.DS_Store\",\n\t\t-- \"**/*.tmp\",\n\t\t-- \"cache/**\",\n\t},\n}\n\nreturn config\n", osKey, prettyTarget, prettyTarget) - return os.WriteFile(path, []byte(content), 0o644) -} - -func promptWithDefault(reader *bufio.Reader, label, def string) string { - if def != "" { - fmt.Printf("%s [%s]: ", label, def) - } else { - fmt.Printf("%s: ", label) - } - - text, _ := reader.ReadString('\n') - text = strings.TrimSpace(text) - if text == "" { - return def - } - return text -} - -func promptYesNo(message string, def bool) (bool, error) { - reader := bufio.NewReader(os.Stdin) - defLabel := "y/N" - if def { - defLabel = "Y/n" - } - fmt.Printf("%s [%s]: ", message, defLabel) - text, err := reader.ReadString('\n') - if err != nil { - return false, err - } - text = strings.TrimSpace(strings.ToLower(text)) - if text == "" { - return def, nil - } - return text == "y" || text == "yes", nil -} - -func moveDirContents(srcDir, destDir string) error { - entries, err := os.ReadDir(srcDir) - if err != nil { - return err - } - - for _, entry := range entries { - srcPath := filepath.Join(srcDir, entry.Name()) - destPath := filepath.Join(destDir, entry.Name()) - - if _, err := os.Stat(destPath); err == nil { - return fmt.Errorf("destination already exists: %s", destPath) - } else if !errors.Is(err, os.ErrNotExist) { - return err - } - - if err := os.Rename(srcPath, destPath); err != nil { - return err - } - } - - return nil -} - -func findStaleLinks(filesDir, targetRoot string) ([]string, error) { - filesAbs, err := filepath.Abs(filesDir) - if err != nil { - return nil, err - } - - var stale []string - walkErr := filepath.WalkDir(targetRoot, func(path string, entry fs.DirEntry, err error) error { - if err != nil { - return err - } - if entry.IsDir() { - return nil - } - - info, err := os.Lstat(path) - if err != nil { - return err - } - if info.Mode()&os.ModeSymlink == 0 { - return nil - } - - src, err := os.Readlink(path) - if err != nil { - return err - } - if !filepath.IsAbs(src) { - src = filepath.Join(filepath.Dir(path), src) - } - src = filepath.Clean(src) - - if !strings.HasPrefix(src, filesAbs+string(os.PathSeparator)) && src != filesAbs { - return nil - } - - if _, err := os.Stat(src); errors.Is(err, os.ErrNotExist) { - stale = append(stale, path) - } - return nil - }) - - if walkErr != nil { - return nil, walkErr - } - - return stale, nil -} - -func findPackageByTarget(repo, targetRoot string) (string, error) { - repoEntries, err := os.ReadDir(repo) - if err != nil { - return "", err - } - - absTarget, err := filepath.Abs(expandHome(targetRoot)) - if err != nil { - return "", err - } - - for _, entry := range repoEntries { - if !entry.IsDir() || strings.HasPrefix(entry.Name(), ".") { - continue - } - - pkgDir := filepath.Join(repo, entry.Name()) - configPath := filepath.Join(pkgDir, configFileName) - cfg, err := loadConfig(configPath) - if err != nil { - continue - } - target, err := selectTarget(cfg) - if err != nil { - if errors.Is(err, errTargetDisabled) { - continue - } - continue - } - absPkgTarget, err := filepath.Abs(expandHome(target)) - if err != nil { - continue - } - if absPkgTarget == absTarget { - return entry.Name(), nil - } - } - - return "", nil -} - -func removeLinks(paths []string, dryRun bool) error { - for _, path := range paths { - if dryRun { - fmt.Printf("dry-run: remove %s\n", path) - continue - } - if err := os.Remove(path); err != nil { - return err - } - fmt.Printf("removed %s\n", path) - } - return nil -} - -func handleStaleLinks(stales []string) error { - if len(stales) == 0 { - return nil - } - - repo, err := repoPath() - if err != nil { - return err - } - - reader := bufio.NewReader(os.Stdin) - for _, path := range stales { - fmt.Printf("stale: %s\n", path) - - canUnlink, err := staleHasRepoFile(path, repo) - if err != nil { - return err - } - - prompt := "action [p=prune, u=unlink, i=ignore]: " - if !canUnlink { - prompt = "action [p=prune, i=ignore]: " - } - - fmt.Print(prompt) - choice, err := reader.ReadString('\n') - if err != nil { - return err - } - choice = strings.TrimSpace(strings.ToLower(choice)) - if choice == "" || choice == "i" { - continue - } - if choice == "p" { - if err := os.Remove(path); err != nil { - return err - } - fmt.Printf("removed %s\n", path) - continue - } - if choice == "u" && canUnlink { - if err := unlinkStale(path, repo); err != nil { - return err - } - continue - } - fmt.Println("invalid choice; skipping") - } - - return nil -} - -func staleHasRepoFile(targetPath, repo string) (bool, error) { - repoPath, err := repoPathForTarget(targetPath, repo) - if err != nil { - return false, err - } - if repoPath == "" { - return false, nil - } - if _, err := os.Stat(repoPath); errors.Is(err, os.ErrNotExist) { - return false, nil - } else if err != nil { - return false, err - } - return true, nil -} - -func unlinkStale(targetPath, repo string) error { - repoPath, err := repoPathForTarget(targetPath, repo) - if err != nil { - return err - } - if repoPath == "" { - return nil - } - - if err := os.Remove(targetPath); err != nil { - return err - } - - if err := copyFile(repoPath, targetPath); err != nil { - return err - } - - if err := os.Remove(repoPath); err != nil { - return err - } - - fmt.Printf("unlinked %s (removed %s)\n", targetPath, repoPath) - return nil -} - -func repoPathForTarget(targetPath, repo string) (string, error) { - info, err := os.Lstat(targetPath) - if err != nil { - return "", err - } - if info.Mode()&os.ModeSymlink == 0 { - return "", nil - } - - src, err := os.Readlink(targetPath) - if err != nil { - return "", err - } - if !filepath.IsAbs(src) { - src = filepath.Join(filepath.Dir(targetPath), src) - } - src = filepath.Clean(src) - - repoAbs, err := filepath.Abs(repo) - if err != nil { - return "", err - } - - rel, err := filepath.Rel(repoAbs, src) - if err != nil { - return "", err - } - if strings.HasPrefix(rel, "..") { - return "", nil - } - - parts := strings.Split(rel, string(os.PathSeparator)) - if len(parts) < 3 { - return "", nil - } - if parts[1] != filesDirName { - return "", nil - } - - return filepath.Join(repoAbs, rel), nil -} - -func restorePackage(filesDir, targetRoot string, dryRun bool) error { - filesAbs, err := filepath.Abs(filesDir) - if err != nil { - return err - } - - return filepath.WalkDir(filesDir, func(path string, entry fs.DirEntry, err error) error { - if err != nil { - return err - } - if path == filesDir { - return nil - } - if entry.IsDir() { - return nil - } - - rel, err := filepath.Rel(filesDir, path) - if err != nil { - return err - } - targetPath := filepath.Join(targetRoot, rel) - return restoreOne(path, targetPath, filesAbs, dryRun) - }) -} - -func restorePath(filesDir, targetRoot, relPath string, dryRun bool) error { - filesAbs, err := filepath.Abs(filesDir) - if err != nil { - return err - } - - relPath = filepath.Clean(relPath) - if strings.HasPrefix(relPath, "..") || filepath.IsAbs(relPath) { - return fmt.Errorf("invalid relative path %q", relPath) - } - - sourcePath := filepath.Join(filesDir, relPath) - info, err := os.Lstat(sourcePath) - if err != nil { - return err - } - - if info.IsDir() { - return filepath.WalkDir(sourcePath, func(path string, entry fs.DirEntry, err error) error { - if err != nil { - return err - } - if entry.IsDir() { - return nil - } - rel, err := filepath.Rel(filesDir, path) - if err != nil { - return err - } - targetPath := filepath.Join(targetRoot, rel) - return restoreOne(path, targetPath, filesAbs, dryRun) - }) - } - - rel, err := filepath.Rel(filesDir, sourcePath) - if err != nil { - return err - } - return restoreOne(sourcePath, filepath.Join(targetRoot, rel), filesAbs, dryRun) -} - -func restoreOne(sourcePath, targetPath, filesAbs string, dryRun bool) error { - info, err := os.Lstat(targetPath) - if errors.Is(err, os.ErrNotExist) { - return nil - } - if err != nil { - return err - } - if info.Mode()&os.ModeSymlink == 0 { - return nil - } - - src, err := os.Readlink(targetPath) - if err != nil { - return err - } - if !filepath.IsAbs(src) { - src = filepath.Join(filepath.Dir(targetPath), src) - } - src = filepath.Clean(src) - - if !strings.HasPrefix(src, filesAbs+string(os.PathSeparator)) && src != filesAbs { - return nil - } - - if dryRun { - fmt.Printf("dry-run: restore %s\n", targetPath) - return nil - } - - if err := os.Remove(targetPath); err != nil { - return err - } - - fmt.Printf("restored %s\n", targetPath) - return copyFile(sourcePath, targetPath) -} - -func copyFile(src, dst string) error { - if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { - return err - } - - srcFile, err := os.Open(src) - if err != nil { - return err - } - defer srcFile.Close() - - info, err := srcFile.Stat() - if err != nil { - return err - } - - dstFile, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, info.Mode()) - if err != nil { - return err - } - defer dstFile.Close() - - if _, err := io.Copy(dstFile, srcFile); err != nil { - return err - } - - return nil -} diff --git a/ops.go b/ops.go new file mode 100644 index 0000000..0e41648 --- /dev/null +++ b/ops.go @@ -0,0 +1,402 @@ +package main + +import ( + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" +) + +func applyPackage(filesDir, targetRoot string, cfg *packageConfig) error { + if err := ensureDir(targetRoot); err != nil { + return err + } + + return filepath.WalkDir(filesDir, func(path string, entry fs.DirEntry, err error) error { + if err != nil { + return err + } + if path == filesDir { + return nil + } + + rel, err := filepath.Rel(filesDir, path) + if err != nil { + return err + } + + if shouldIgnorePath(rel, cfg) { + if entry.IsDir() { + return filepath.SkipDir + } + return nil + } + + targetPath := filepath.Join(targetRoot, rel) + + if entry.IsDir() { + return ensureDir(targetPath) + } + + srcAbs, err := filepath.Abs(path) + if err != nil { + return err + } + return linkFile(srcAbs, targetPath) + }) +} + +func linkFile(src, dst string) error { + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return err + } + + if info, err := os.Lstat(dst); err == nil { + if info.Mode()&os.ModeSymlink != 0 { + current, err := os.Readlink(dst) + if err != nil { + return err + } + if current == src { + return nil + } + return fmt.Errorf("conflict at %s (points to %s)", dst, current) + } + return fmt.Errorf("conflict at %s (exists and is not a symlink)", dst) + } else if !errors.Is(err, os.ErrNotExist) { + return err + } + + return os.Symlink(src, dst) +} + +func ensureDir(path string) error { + info, err := os.Lstat(path) + if err == nil { + if info.IsDir() { + return nil + } + return fmt.Errorf("%s exists and is not a directory", path) + } + if !errors.Is(err, os.ErrNotExist) { + return err + } + return os.MkdirAll(path, 0o755) +} + +func findStaleLinks(filesDir, targetRoot string) ([]string, error) { + filesAbs, err := filepath.Abs(filesDir) + if err != nil { + return nil, err + } + + var stale []string + walkErr := filepath.WalkDir(targetRoot, func(path string, entry fs.DirEntry, err error) error { + if err != nil { + return err + } + if entry.IsDir() { + return nil + } + + info, err := os.Lstat(path) + if err != nil { + return err + } + if info.Mode()&os.ModeSymlink == 0 { + return nil + } + + src, err := os.Readlink(path) + if err != nil { + return err + } + if !filepath.IsAbs(src) { + src = filepath.Join(filepath.Dir(path), src) + } + src = filepath.Clean(src) + + if !strings.HasPrefix(src, filesAbs+string(os.PathSeparator)) && src != filesAbs { + return nil + } + + if _, err := os.Stat(src); errors.Is(err, os.ErrNotExist) { + stale = append(stale, path) + } + return nil + }) + + if walkErr != nil { + return nil, walkErr + } + + return stale, nil +} + +func removeLinks(paths []string, dryRun bool) error { + for _, path := range paths { + if dryRun { + fmt.Printf("dry-run: remove %s\n", path) + continue + } + if err := os.Remove(path); err != nil { + return err + } + fmt.Printf("removed %s\n", path) + } + return nil +} + +func handleStaleLinks(stales []string) error { + if len(stales) == 0 { + return nil + } + + repo, err := repoPath() + if err != nil { + return err + } + + reader := newReader() + for _, path := range stales { + fmt.Printf("stale: %s\n", path) + + canUnlink, err := staleHasRepoFile(path, repo) + if err != nil { + return err + } + + prompt := "action [p=prune, u=unlink, i=ignore]: " + if !canUnlink { + prompt = "action [p=prune, i=ignore]: " + } + + fmt.Print(prompt) + choice, err := reader.ReadString('\n') + if err != nil { + return err + } + choice = strings.TrimSpace(strings.ToLower(choice)) + if choice == "" || choice == "i" { + continue + } + if choice == "p" { + if err := os.Remove(path); err != nil { + return err + } + fmt.Printf("removed %s\n", path) + continue + } + if choice == "u" && canUnlink { + if err := unlinkStale(path, repo); err != nil { + return err + } + continue + } + fmt.Println("invalid choice; skipping") + } + + return nil +} + +func staleHasRepoFile(targetPath, repo string) (bool, error) { + repoPath, err := repoPathForTarget(targetPath, repo) + if err != nil { + return false, err + } + if repoPath == "" { + return false, nil + } + if _, err := os.Stat(repoPath); errors.Is(err, os.ErrNotExist) { + return false, nil + } else if err != nil { + return false, err + } + return true, nil +} + +func unlinkStale(targetPath, repo string) error { + repoPath, err := repoPathForTarget(targetPath, repo) + if err != nil { + return err + } + if repoPath == "" { + return nil + } + + if err := os.Remove(targetPath); err != nil { + return err + } + + if err := copyFile(repoPath, targetPath); err != nil { + return err + } + + if err := os.Remove(repoPath); err != nil { + return err + } + + fmt.Printf("unlinked %s (removed %s)\n", targetPath, repoPath) + return nil +} + +func restorePackage(filesDir, targetRoot string, dryRun bool) error { + filesAbs, err := filepath.Abs(filesDir) + if err != nil { + return err + } + + return filepath.WalkDir(filesDir, func(path string, entry fs.DirEntry, err error) error { + if err != nil { + return err + } + if path == filesDir { + return nil + } + if entry.IsDir() { + return nil + } + + rel, err := filepath.Rel(filesDir, path) + if err != nil { + return err + } + targetPath := filepath.Join(targetRoot, rel) + return restoreOne(path, targetPath, filesAbs, dryRun) + }) +} + +func restorePath(filesDir, targetRoot, relPath string, dryRun bool) error { + filesAbs, err := filepath.Abs(filesDir) + if err != nil { + return err + } + + relPath = filepath.Clean(relPath) + if strings.HasPrefix(relPath, "..") || filepath.IsAbs(relPath) { + return fmt.Errorf("invalid relative path %q", relPath) + } + + sourcePath := filepath.Join(filesDir, relPath) + info, err := os.Lstat(sourcePath) + if err != nil { + return err + } + + if info.IsDir() { + return filepath.WalkDir(sourcePath, func(path string, entry fs.DirEntry, err error) error { + if err != nil { + return err + } + if entry.IsDir() { + return nil + } + rel, err := filepath.Rel(filesDir, path) + if err != nil { + return err + } + targetPath := filepath.Join(targetRoot, rel) + return restoreOne(path, targetPath, filesAbs, dryRun) + }) + } + + rel, err := filepath.Rel(filesDir, sourcePath) + if err != nil { + return err + } + return restoreOne(sourcePath, filepath.Join(targetRoot, rel), filesAbs, dryRun) +} + +func restoreOne(sourcePath, targetPath, filesAbs string, dryRun bool) error { + info, err := os.Lstat(targetPath) + if errors.Is(err, os.ErrNotExist) { + return nil + } + if err != nil { + return err + } + if info.Mode()&os.ModeSymlink == 0 { + return nil + } + + src, err := os.Readlink(targetPath) + if err != nil { + return err + } + if !filepath.IsAbs(src) { + src = filepath.Join(filepath.Dir(targetPath), src) + } + src = filepath.Clean(src) + + if !strings.HasPrefix(src, filesAbs+string(os.PathSeparator)) && src != filesAbs { + return nil + } + + if dryRun { + fmt.Printf("dry-run: restore %s\n", targetPath) + return nil + } + + if err := os.Remove(targetPath); err != nil { + return err + } + + fmt.Printf("restored %s\n", targetPath) + return copyFile(sourcePath, targetPath) +} + +func copyFile(src, dst string) error { + if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { + return err + } + + srcFile, err := os.Open(src) + if err != nil { + return err + } + defer srcFile.Close() + + info, err := srcFile.Stat() + if err != nil { + return err + } + + dstFile, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, info.Mode()) + if err != nil { + return err + } + defer dstFile.Close() + + if _, err := io.Copy(dstFile, srcFile); err != nil { + return err + } + + return nil +} + +func moveDirContents(srcDir, destDir string) error { + entries, err := os.ReadDir(srcDir) + if err != nil { + return err + } + + for _, entry := range entries { + srcPath := filepath.Join(srcDir, entry.Name()) + destPath := filepath.Join(destDir, entry.Name()) + + if _, err := os.Stat(destPath); err == nil { + return fmt.Errorf("destination already exists: %s", destPath) + } else if !errors.Is(err, os.ErrNotExist) { + return err + } + + if err := os.Rename(srcPath, destPath); err != nil { + return err + } + } + + return nil +} diff --git a/path.go b/path.go new file mode 100644 index 0000000..be16957 --- /dev/null +++ b/path.go @@ -0,0 +1,244 @@ +package main + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" +) + +func repoPath() (string, error) { + if override := os.Getenv("SIGIL_REPO"); override != "" { + return filepath.Abs(expandHome(override)) + } + return filepath.Abs(expandHome("~/.dotfiles")) +} + +func expandHome(path string) string { + if path == "~" { + home, err := os.UserHomeDir() + if err != nil { + return path + } + return home + } + if strings.HasPrefix(path, "~/") { + home, err := os.UserHomeDir() + if err != nil { + return path + } + return filepath.Join(home, path[2:]) + } + return path +} + +func compressHome(path string) string { + home, err := os.UserHomeDir() + if err != nil { + return path + } + clean := filepath.Clean(path) + homeClean := filepath.Clean(home) + if clean == homeClean { + return "~" + } + if strings.HasPrefix(clean, homeClean+string(os.PathSeparator)) { + rel := strings.TrimPrefix(clean, homeClean+string(os.PathSeparator)) + return filepath.Join("~", rel) + } + return path +} + +func splitPackageSpec(spec string) (string, string, error) { + if spec == "" { + return "", "", errors.New("missing package") + } + + parts := strings.SplitN(spec, ":", 2) + pkg := parts[0] + rel := "" + if len(parts) == 2 { + rel = parts[1] + } + + pkg = strings.Trim(pkg, "/") + rel = strings.TrimPrefix(rel, "/") + + if pkg == "" { + return "", "", errors.New("invalid package") + } + + return pkg, rel, nil +} + +func resolvePackageSpec(spec string) (string, string, error) { + repo, err := repoPath() + if err != nil { + return "", "", err + } + + repoAbs, err := filepath.Abs(repo) + if err != nil { + return "", "", err + } + + spec = expandHome(spec) + if filepath.IsAbs(spec) { + return resolvePathSpec(spec, repoAbs) + } + + if strings.Contains(spec, string(os.PathSeparator)) { + return resolvePathSpec(spec, repoAbs) + } + + clean := filepath.Clean(spec) + if strings.HasPrefix(clean, ".") || strings.HasPrefix(clean, string(os.PathSeparator)) { + return resolvePathSpec(clean, repoAbs) + } + + return splitPackageSpec(spec) +} + +func resolvePathSpec(pathSpec, repoAbs string) (string, string, error) { + absPath, err := filepath.Abs(pathSpec) + if err != nil { + return "", "", err + } + + if strings.HasPrefix(absPath, repoAbs+string(os.PathSeparator)) || absPath == repoAbs { + rel, err := filepath.Rel(repoAbs, absPath) + if err != nil { + return "", "", err + } + parts := strings.Split(rel, string(os.PathSeparator)) + if len(parts) >= 1 { + pkg := parts[0] + if len(parts) >= 2 && parts[1] == filesDirName { + relPath := filepath.Join(parts[2:]...) + return pkg, relPath, nil + } + return pkg, filepath.Join(parts[1:]...), nil + } + } + + entries, err := os.ReadDir(repoAbs) + if err != nil { + return "", "", err + } + + for _, entry := range entries { + if !entry.IsDir() || strings.HasPrefix(entry.Name(), ".") { + continue + } + pkgDir := filepath.Join(repoAbs, entry.Name()) + configPath := filepath.Join(pkgDir, configFileName) + cfg, err := loadConfig(configPath) + if err != nil { + continue + } + targetRoot, err := selectTarget(cfg) + if err != nil { + if errors.Is(err, errTargetDisabled) { + continue + } + continue + } + absTarget, err := filepath.Abs(expandHome(targetRoot)) + if err != nil { + continue + } + if strings.HasPrefix(absPath, absTarget+string(os.PathSeparator)) || absPath == absTarget { + rel, err := filepath.Rel(absTarget, absPath) + if err != nil { + return "", "", err + } + return entry.Name(), rel, nil + } + } + + return "", "", fmt.Errorf("could not resolve %s to a package", pathSpec) +} + +func findPackageByTarget(repo, targetRoot string) (string, error) { + repoEntries, err := os.ReadDir(repo) + if err != nil { + return "", err + } + + absTarget, err := filepath.Abs(expandHome(targetRoot)) + if err != nil { + return "", err + } + + for _, entry := range repoEntries { + if !entry.IsDir() || strings.HasPrefix(entry.Name(), ".") { + continue + } + + pkgDir := filepath.Join(repo, entry.Name()) + configPath := filepath.Join(pkgDir, configFileName) + cfg, err := loadConfig(configPath) + if err != nil { + continue + } + target, err := selectTarget(cfg) + if err != nil { + if errors.Is(err, errTargetDisabled) { + continue + } + continue + } + absPkgTarget, err := filepath.Abs(expandHome(target)) + if err != nil { + continue + } + if absPkgTarget == absTarget { + return entry.Name(), nil + } + } + + return "", nil +} + +func repoPathForTarget(targetPath, repo string) (string, error) { + info, err := os.Lstat(targetPath) + if err != nil { + return "", err + } + if info.Mode()&os.ModeSymlink == 0 { + return "", nil + } + + src, err := os.Readlink(targetPath) + if err != nil { + return "", err + } + if !filepath.IsAbs(src) { + src = filepath.Join(filepath.Dir(targetPath), src) + } + src = filepath.Clean(src) + + repoAbs, err := filepath.Abs(repo) + if err != nil { + return "", err + } + + rel, err := filepath.Rel(repoAbs, src) + if err != nil { + return "", err + } + if strings.HasPrefix(rel, "..") { + return "", nil + } + + parts := strings.Split(rel, string(os.PathSeparator)) + if len(parts) < 3 { + return "", nil + } + if parts[1] != filesDirName { + return "", nil + } + + return filepath.Join(repoAbs, rel), nil +} diff --git a/sigil b/sigil index 29b8b1d..bf10623 100755 Binary files a/sigil and b/sigil differ diff --git a/util.go b/util.go new file mode 100644 index 0000000..3a8a03f --- /dev/null +++ b/util.go @@ -0,0 +1,78 @@ +package main + +import ( + "bufio" + "errors" + "fmt" + "os" + "strings" +) + +type packageFlags struct { + dryRun bool +} + +func parsePackageFlags(args []string) (packageFlags, string, error) { + flags := packageFlags{} + var pkg string + for _, arg := range args { + if arg == "--dry-run" { + flags.dryRun = true + continue + } + if strings.HasPrefix(arg, "-") { + return flags, "", fmt.Errorf("unknown flag %q", arg) + } + if pkg != "" { + return flags, "", errors.New("too many arguments") + } + pkg = arg + } + if pkg == "" { + return flags, "", errors.New("missing package") + } + return flags, pkg, nil +} + +var stdinReader *bufio.Reader + +func newReader() *bufio.Reader { + if stdinReader == nil { + stdinReader = bufio.NewReader(os.Stdin) + } + return stdinReader +} + +func promptWithDefault(label, def string) string { + reader := newReader() + if def != "" { + fmt.Printf("%s [%s]: ", label, def) + } else { + fmt.Printf("%s: ", label) + } + + text, _ := reader.ReadString('\n') + text = strings.TrimSpace(text) + if text == "" { + return def + } + return text +} + +func promptYesNo(message string, def bool) (bool, error) { + reader := newReader() + defLabel := "y/N" + if def { + defLabel = "Y/n" + } + fmt.Printf("%s [%s]: ", message, defLabel) + text, err := reader.ReadString('\n') + if err != nil { + return false, err + } + text = strings.TrimSpace(strings.ToLower(text)) + if text == "" { + return def, nil + } + return text == "y" || text == "yes", nil +}