diff --git a/common.go b/common.go new file mode 100644 index 0000000..6dfcc03 --- /dev/null +++ b/common.go @@ -0,0 +1,181 @@ +package main + +import ( + "errors" + "fmt" + "math/rand" + "net" + + "github.com/sirupsen/logrus" + "libvirt.org/go/libvirt" + "libvirt.org/go/libvirtxml" + + "mkvm/config" + "mkvm/libvirtx" + "mkvm/volumes/pools" +) + +type domainModifier func(*libvirtxml.Domain) error + +func getLibvirtAndPool() (*libvirt.Connect, pools.StoragePool, error) { + conn, err := libvirtx.New() + if err != nil { + return nil, nil, err + } + defer conn.Close() + + pool, err := pools.GetPool(conn, config.C.StoragePool) + if err != nil { + return nil, nil, err + } + + return conn, pool, nil +} + +func getServerListenAddress(conn *libvirt.Connect) (string, error) { + serverInterfaceName := "" + if config.C.Network != "" { + nicSource.Network = &libvirtxml.DomainInterfaceSourceNetwork{Network: config.C.Network} + libvirtnet, err := conn.LookupNetworkByName(config.C.Network) + if err != nil { + logrus.WithField("network", config.C.Network).Error("error finding libvirt network") + return "", err + } + + xmlstr, err := libvirtnet.GetXMLDesc(0) + if err != nil { + logrus.WithField("network", config.C.Network).Error("error getting network xml description") + return "", err + } + + var net libvirtxml.Network + if err := net.Unmarshal(xmlstr); err != nil { + logrus.WithField("network", config.C.Network).Error("error parsing network xml description") + return "", err + } + + serverInterfaceName = net.Bridge.Name + } else if config.C.Bridge != "" { + nicSource.Bridge = &libvirtxml.DomainInterfaceSourceBridge{Bridge: config.C.Bridge} + serverInterfaceName = config.C.Bridge + } else { + return "", errors.New("no network or bridge configured") + } + + serverInterface, err := net.InterfaceByName(serverInterfaceName) + if err != nil { + logrus.Error("error finding local network interface to run server on") + return "", err + } + + serverInterfaceAddrs, err := serverInterface.Addrs() + if err != nil { + logrus.Error("error finding local network interface's IP") + return "", err + } + + if len(serverInterfaceAddrs) == 0 { + return "", fmt.Errorf("bridge interface %s does not have an IP on this machine", serverInterfaceName) + } + + serverBindIP, _, err := net.ParseCIDR(serverInterfaceAddrs[0].String()) + if err != nil { + logrus.WithField("interface", serverInterfaceName).WithField("address", serverInterfaceAddrs[0].String()).Error("error parsing local address") + return "", err + } + + port := rand.Intn(65535-1025) + 1025 + return fmt.Sprintf("%s:%d", serverBindIP, port), nil +} + +func createDomain(conn *libvirt.Connect, pool pools.StoragePool, name string, modifiers ...domainModifier) error { + interfaces := []libvirtxml.DomainInterface{ + { + Model: &libvirtxml.DomainInterfaceModel{Type: "virtio"}, + Source: &nicSource, + }, + } + + // Create the domain + domainXML := &libvirtxml.Domain{ + Type: "kvm", + Name: name, + Memory: &libvirtxml.DomainMemory{Value: uint(argMemoryMB), Unit: "MiB"}, + VCPU: &libvirtxml.DomainVCPU{Value: uint(argCPUs)}, + OS: &libvirtxml.DomainOS{ + Type: &libvirtxml.DomainOSType{Arch: "x86_64", Type: "hvm"}, + BootDevices: []libvirtxml.DomainBootDevice{{Dev: "hd"}}, + }, + Features: &libvirtxml.DomainFeatureList{ + ACPI: &libvirtxml.DomainFeature{}, + APIC: &libvirtxml.DomainFeatureAPIC{}, + VMPort: &libvirtxml.DomainFeatureState{State: "off"}, + }, + CPU: &libvirtxml.DomainCPU{Mode: "host-model"}, + Devices: &libvirtxml.DomainDeviceList{ + Emulator: "/usr/bin/kvm", + Disks: []libvirtxml.DomainDisk{pool.GetDomainDiskXML(name)}, + Channels: []libvirtxml.DomainChannel{ + { + Source: &libvirtxml.DomainChardevSource{ + UNIX: &libvirtxml.DomainChardevSourceUNIX{Path: "/var/lib/libvirt/qemu/f16x86_64.agent", Mode: "bind"}, + }, + Target: &libvirtxml.DomainChannelTarget{ + VirtIO: &libvirtxml.DomainChannelTargetVirtIO{Name: "org.qemu.guest_agent.0"}, + }, + }, + }, + Consoles: []libvirtxml.DomainConsole{{Target: &libvirtxml.DomainConsoleTarget{}}}, + Serials: []libvirtxml.DomainSerial{{}}, + Interfaces: interfaces, + }, + } + + for _, modifier := range modifiers { + if err := modifier(domainXML); err != nil { + return err + } + } + + domainXMLString, err := domainXML.Marshal() + if err != nil { + return err + } + + logrus.Debug("defining domain from xml") + domain, err := conn.DomainDefineXML(domainXMLString) + if err != nil { + return fmt.Errorf("error defining domain from xml description: %v", err) + } + + logrus.Debug("booting domain") + err = domain.Create() + if err != nil { + return fmt.Errorf("error creating domain: %v", err) + } + + return nil +} + +func setSMBIOS(smbios map[int]map[string]string) domainModifier { + return func(d *libvirtxml.Domain) error { + qemuArgs := []libvirtxml.DomainQEMUCommandlineArg{} + if d.QEMUCommandline != nil { + qemuArgs = d.QEMUCommandline.Args + } + + for smbiosType, values := range smbios { + arg := libvirtxml.DomainQEMUCommandlineArg{ + Value: fmt.Sprintf("type=%d", smbiosType), + } + for key, value := range values { + arg.Value = fmt.Sprintf("%s,%s=%s", arg.Value, key, value) + } + qemuArgs = append(qemuArgs, libvirtxml.DomainQEMUCommandlineArg{Value: "-smbios"}, arg) + } + + d.QEMUCommandline = &libvirtxml.DomainQEMUCommandline{Args: qemuArgs} + + return nil + } +} diff --git a/debian.go b/debian.go new file mode 100644 index 0000000..aa956e9 --- /dev/null +++ b/debian.go @@ -0,0 +1,129 @@ +package main + +import ( + "fmt" + "io" + "mkvm/volumes" + "mkvm/volumes/pools" + "net/http" + "strings" + + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +var ( + debianRelease string + debianPackages []string + + debianFilenameSuffix = map[pools.ImageFormat]string{ + pools.ImageFormatRaw: "-generic-amd64-daily.raw", + pools.ImageFormatQcow2: "-generic-amd64-daily.qcow2", + } +) + +var debianCmd = cobra.Command{ + Use: "debian vm-name", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + name := args[0] + + conn, pool, err := getLibvirtAndPool() + if err != nil { + logrus.WithError(err).Fatal("error connecting to libvirt") + return + } + defer conn.Close() + + diskImageURL, err := debianGetImageURL(debianRelease, pool.ImageFormat()) + if err != nil { + logrus.WithError(err).Fatal("error getting disk image URL") + } + + // download disk image + err = volumes.Create(conn, pool, argDiskSizeGB, diskImageURL, name) + if err != nil { + logrus.WithError(err).Fatal("error creating VM disk") + } + + // prepare cloudconfig and start the http server + bind, err := getServerListenAddress(conn) + if err != nil { + logrus.WithError(err).Fatal("error getting server bind address") + } + + serverURL := fmt.Sprintf("http://%s", bind) + + err = buildCloudConfig(name, fmt.Sprintf("%s/phone-home", serverURL)) + if err != nil { + logrus.WithError(err).Fatal("error building cloud config") + } + + go runHTTPServer(bind) + + smbios := map[int]map[string]string{1: {"serial": fmt.Sprintf("ds=nocloud-net;s=%s/", serverURL)}} + + // create domain + err = createDomain(conn, pool, name, setSMBIOS(smbios)) + if err != nil { + logrus.WithError(err).Fatal("error creating domain") + } + + wg.Add(1) + + logrus.Info("waiting for VM to finish provisioning") + wg.Wait() + }, +} + +func init() { + debianCmd.Flags().StringVarP(&debianRelease, "release", "r", "bookworm", "debian release to install. Options: bookworm (default), trixie, sid") + debianCmd.Flags().StringArrayVarP(&debianPackages, "packages", "p", nil, "apt packages to install") + + rootCmd.AddCommand(&debianCmd) +} + +func debianGetImageURL(release string, format pools.ImageFormat) (string, error) { + imageSuffix, ok := debianFilenameSuffix[format] + if !ok { + return "", fmt.Errorf("unexpected image format %s from storage pool", format) + } + + diskImageURLPrefix := fmt.Sprintf("https://cloud.debian.org/images/cloud/%s/daily/latest", release) + + // find image URL + hash + shaURL := fmt.Sprintf("%s/SHA512SUMS", diskImageURLPrefix) + shaResp, err := http.Get(shaURL) + if err != nil { + return "", err + } + defer shaResp.Body.Close() + + shas, err := io.ReadAll(shaResp.Body) + if err != nil { + return "", err + } + + if shaResp.StatusCode != http.StatusOK { + logrus.WithFields(logrus.Fields{ + "status": shaResp.Status, + "url": shaURL, + "resp": string(shas), + }).Fatal("failed to get image hash") + } + + for _, line := range strings.Split(string(shas), "\n") { + hash, filename, ok := strings.Cut(strings.TrimSpace(line), " ") + if !ok { + continue + } + + if !strings.HasSuffix(filename, imageSuffix) { + continue + } + + return fmt.Sprintf("%s/%s#hash=sha512:%s", diskImageURLPrefix, filename, hash), nil + } + + return "", fmt.Errorf("unable to find hash of image in %s", shaURL) +} diff --git a/http.go b/http.go index 72ca96a..fe0a7e3 100644 --- a/http.go +++ b/http.go @@ -49,14 +49,7 @@ func (httphandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } case http.MethodPost: - body, err := io.ReadAll(r.Body) - if err != nil { - log.WithError(err).Error("error reading body") - } - r.Body.Close() - - log.Debug(string(body)) - + printRequestBody(r) log.Info("VM booted") w.WriteHeader(http.StatusNoContent) wg.Done() @@ -68,7 +61,36 @@ func (httphandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } -func buildCloudConfig(i int, name string, phoneHomeURL string) error { +func buildCloudConfig(name string, phoneHomeURL string) error { + userdata, err := yaml.Marshal(cloudinit.UserData{ + Packages: argPackages, + SSHAuthorizedKeys: argSSHKeys, + PhoneHome: cloudinit.PhoneHome{ + URL: phoneHomeURL, + Post: []string{"pub_key_dsa", "pub_key_rsa", "pub_key_ed25519", "fqdn"}, + }, + }) + if err != nil { + return err + } + + httpStaticContent["/user-data"] = append([]byte("#cloud-config\n"), userdata...) + + metadata, err := yaml.Marshal(cloudinit.MetaData{ + InstanceID: name, + LocalHostname: name, + }) + if err != nil { + return err + } + httpStaticContent["/meta-data"] = metadata + + httpStaticContent["/vendor-data"] = nil + + return nil +} + +func buildCloudConfigPrefix(prefix string, name string, phoneHomeURL string) error { userdata, err := yaml.Marshal(cloudinit.UserData{ Packages: argPackages, SSHAuthorizedKeys: argSSHKeys, @@ -81,7 +103,7 @@ func buildCloudConfig(i int, name string, phoneHomeURL string) error { return err } - httpStaticContent[fmt.Sprintf("/%d/user-data", i)] = append([]byte("#cloud-config\n"), userdata...) + httpStaticContent[fmt.Sprintf("/%s/user-data", prefix)] = append([]byte("#cloud-config\n"), userdata...) metadata, err := yaml.Marshal(cloudinit.MetaData{ InstanceID: name, @@ -90,9 +112,43 @@ func buildCloudConfig(i int, name string, phoneHomeURL string) error { if err != nil { return err } - httpStaticContent[fmt.Sprintf("/%d/meta-data", i)] = metadata + httpStaticContent[fmt.Sprintf("/%s/meta-data", prefix)] = metadata - httpStaticContent[fmt.Sprintf("/%d/vendor-data", i)] = nil + httpStaticContent[fmt.Sprintf("/%s/vendor-data", prefix)] = nil return nil } + +func printRequestBody(r *http.Request) { + log := logrus.WithFields(logrus.Fields{ + "method": r.Method, + "path": r.URL, + "remote_addr": r.RemoteAddr, + }) + + switch r.Header.Get("Content-Type") { + case "application/x-www-form-urlencoded": + if err := r.ParseForm(); err != nil { + logrus.WithError(err).Error("error parsing request body") + return + } + fields := logrus.Fields{} + for k, v := range r.Form { + if len(v) == 1 { + fields[k] = v[0] + } else { + fields[k] = v + } + } + log.WithFields(fields).Debug("form request body") + default: + body, err := io.ReadAll(r.Body) + if err != nil { + log.WithError(err).Error("error reading body") + } + r.Body.Close() + + log.Debug(string(body)) + } + +} diff --git a/main.go b/main.go index 09e36e4..35b5fe6 100644 --- a/main.go +++ b/main.go @@ -24,13 +24,14 @@ var ( Use: "mkvm name [name [name]]", Short: "create virtual machine(s) via libvirt", Args: cobra.MinimumNArgs(1), - Run: func(cmd *cobra.Command, args []string) { + PersistentPreRun: func(cmd *cobra.Command, args []string) { logrus.SetLevel(logrus.DebugLevel) if err := config.Load(); err != nil { logrus.WithError(err).Fatal("error loading config") } - + }, + Run: func(cmd *cobra.Command, args []string) { for _, u := range argSSHKeyURLs { if err := downloadSSHKeys(u); err != nil { logrus.WithError(err).WithField("url", u).Fatal("error downloading SSH keys") @@ -96,7 +97,7 @@ var ( for i, name := range args { metadataURL := fmt.Sprintf("http://%s/%d", serverBind, i) - if err := buildCloudConfig(i, name, metadataURL); err != nil { + if err := buildCloudConfigPrefix(fmt.Sprint(i), name, metadataURL); err != nil { logrus.WithError(err).WithField("vm", name).Error("error building cloudconfig for vm") continue } diff --git a/mkvm.go b/mkvm.go index b7c05cc..629db46 100644 --- a/mkvm.go +++ b/mkvm.go @@ -7,6 +7,7 @@ import ( "libvirt.org/go/libvirt" "libvirt.org/go/libvirtxml" + "mkvm/config" "mkvm/volumes" "mkvm/volumes/pools" ) @@ -15,7 +16,7 @@ func mkvm(conn *libvirt.Connect, metadataURL string, name string) error { logger := logrus.WithField("vm", name) logger.Debug("creating vm") - pool, err := pools.GetPool(conn) + pool, err := pools.GetPool(conn, config.C.StoragePool) if err != nil { return err } diff --git a/volumes/pools/storage.go b/volumes/pools/storage.go index 713e1f5..193075b 100644 --- a/volumes/pools/storage.go +++ b/volumes/pools/storage.go @@ -6,8 +6,6 @@ package pools import ( "libvirt.org/go/libvirt" "libvirt.org/go/libvirtxml" - - "mkvm/config" ) type StoragePoolType string @@ -29,8 +27,8 @@ type StoragePool interface { var drivers = map[string]func(*libvirt.StoragePool) (StoragePool, error){} // GetPool retrieves the configured Storage Pool from libvirt -func GetPool(conn *libvirt.Connect) (StoragePool, error) { - pool, err := conn.LookupStoragePoolByName(config.C.StoragePool) +func GetPool(conn *libvirt.Connect, name string) (StoragePool, error) { + pool, err := conn.LookupStoragePoolByName(name) if err != nil { return nil, err }