/* socktest.c */

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <poll.h>
#include <unistd.h>

/* Receive timeout in seconds, applied to each socket by tcp_bind(). */
#define TIMEOUT_SECS 2
/* printf-style template of the request; %s is replaced by the destination host. */
#define FORMAT "OPTIONS * HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n"
/* Capacity in bytes of the outgoing message buffer. */
#define MESSAGE_LEN 128
/* Capacity in bytes of the receive buffer. */
#define MAXLINE 1024
/* Timeout in milliseconds passed to each poll() call. */
#define POLL_TIMEOUT_MSECS 20

/* Outgoing request, formatted once in main(). */
char message[MESSAGE_LEN];
/* Scratch receive buffer; every read() overwrites it from the start. */
char line[MAXLINE];
/* Length of the formatted request as returned by snprintf(). */
int message_len;
/* Offset into message of the next byte to send in non-blocking mode. */
int message_pos=0;
/* Stages of the non-blocking exchange stepped through by sendnb(). */
typedef enum { initial, connecting, writing, reading, done, error } connect_state;
/* Current stage; written by main(), do_connect() and the try_* helpers. */
connect_state state;
/**
 * Apply a receive timeout (SO_RCVTIMEO) to a socket.
 * @param sockfd the socket file descriptor
 * @param secs the timeout in whole seconds
 * @return 1 if the option was set, else 0
 */
static int socket_set_timeout( int sockfd, int secs )
{
    struct timeval timeout = { .tv_sec = secs, .tv_usec = 0 };
    int rc = setsockopt( sockfd, SOL_SOCKET, SO_RCVTIMEO,
        &timeout, sizeof(timeout) );

    if ( rc == 0 )
        return 1;
    /* setsockopt reports failure as -1 */
    if ( rc == -1 )
        printf("failed to set timeout on socket\n");
    return 0;
}
/**
 * Write exactly n bytes, restarting after short writes and interrupts.
 * @param fd the descriptor to write to
 * @param vptr the data to write
 * @param n its length
 * @return -1 on error, else number of bytes written (always n)
 */
static ssize_t writen( int fd, const void *vptr, size_t n )
{
    const char *ptr = vptr;
    size_t nleft = n;

    while ( nleft > 0 )
    {
        ssize_t nwritten = write( fd, ptr, nleft );
        if ( nwritten < 0 )
        {
            /* errno is only meaningful when write() reported failure */
            if ( errno == EINTR )
                continue;
            return -1;
        }
        /* no progress and no error: give up rather than spin forever */
        if ( nwritten == 0 )
            return -1;
        nleft -= (size_t)nwritten;
        ptr += nwritten;
    }
    return (ssize_t)n;
}
/**
 * Read from a socket until the peer closes the connection. The data is
 * placed in the shared line buffer, each read overwriting the previous
 * one; only the byte count is kept.
 * @param sock the socket to read from
 * @return number of bytes read or -1 on error
 */
static int readn( int sock )
{
    int total = 0;

    for ( ; ; )
    {
        ssize_t n = read( sock, line, MAXLINE );
        if ( n > 0 )
        {
            total += (int)n;
            continue;
        }
        /* end of stream: just finished reading */
        if ( n == 0 )
            break;
        /* an interrupted read is not a failure; restart it as writen does */
        if ( errno == EINTR )
            continue;
        printf( "failed to read. err=%s socket=%d\n",strerror(errno),sock);
        return -1;
    }
    return total;
}
/**
 * Connect the given socket to the host and its destination port
 * @param sock the socket to connect
 * @param host the host IPv4 address in dotted-decimal form
 * @param port the host port as a decimal string (1-65535)
 * @return 1 if the connection is established, else 0. On a 0 return the
 * global state is set to connecting when the attempt is still in progress
 * (EINTR or EINPROGRESS) and to error otherwise.
 */
int do_connect( int sock, char *host, char *port )
{
    struct sockaddr_in addr;
    char *end = NULL;
    long portnum;
    int res;

    /* clear addr structure first */
    memset( &addr, 0, sizeof(addr) );
    /* the family must be set explicitly: a zeroed field means AF_UNSPEC */
    addr.sin_family = AF_INET;

    res = inet_pton( AF_INET, host, &addr.sin_addr );
    if ( res != 1 )
    {
        /* a 0 return means the text was not a valid address; errno is
           only set when inet_pton returns -1 */
        if ( res == 0 )
            printf("inet_pton failed: %s is not a valid IPv4 address\n",host);
        else
            printf("inet_pton failed: %s\n",strerror(errno) );
        state = error;
        return 0;
    }

    /* strtol lets us reject non-numeric or out-of-range ports */
    errno = 0;
    portnum = strtol( port, &end, 10 );
    if ( errno != 0 || end == port || *end != '\0'
        || portnum < 1 || portnum > 65535 )
    {
        printf("invalid port %s\n",port);
        state = error;
        return 0;
    }
    /* port number must be in network byte order */
    addr.sin_port = htons( (unsigned short)portnum );

    /* establish TCP connection via handshake (SYN,SYN-ACK,ACK) */
    res = connect( sock, (const struct sockaddr *)&addr, sizeof(addr) );
    if ( res == 0 )
    {
        printf("connected successfully to %s on port %s\n",host,port);
        return 1;
    }
    /* the attempt continues in the background; caller must poll for it */
    if ( errno == EINTR || errno == EINPROGRESS )
    {
        state = connecting;
        return 0;
    }
    printf("couldn't connect to %s on port %s\n",host,port);
    state = error;
    return 0;
}
/**
 * In non-blocking mode we must poll the socket to see if the connection
 * is established. Readiness alone is not enough: a failed connect also
 * wakes poll(), so the pending socket error is checked as well.
 * @param sock the socket to test for connection
 * @return 1 if it is now connected, else 0 (the global state is set to
 * error when the attempt failed; it is left alone while still pending)
 */
int try_connect( int sock )
{
    struct pollfd fds[1];
    int so_error = 0;
    socklen_t len = sizeof(so_error);
    int res;

    fds[0].fd = sock;
    fds[0].events = POLLWRBAND | POLLOUT;
    fds[0].revents = 0;

    res = poll( fds, 1, POLL_TIMEOUT_MSECS );
    if ( res == 0 )
        return 0;
    if ( res < 0 )
    {
        /* an interrupted poll just means "ask again" */
        if ( errno != EINTR )
            state = error;
        return 0;
    }

    /* the outcome of an asynchronous connect is reported via SO_ERROR */
    if ( getsockopt( sock, SOL_SOCKET, SO_ERROR, &so_error, &len ) == -1 )
    {
        printf("error: %s\n",strerror(errno));
        state = error;
        return 0;
    }
    if ( so_error != 0 )
    {
        printf("couldn't connect: %s\n",strerror(so_error));
        state = error;
        return 0;
    }
    if ( fds[0].revents & (POLLOUT | POLLWRBAND) )
        return 1;

    /* woken only by POLLERR/POLLHUP/POLLNVAL: the socket is unusable */
    state = error;
    return 0;
}
/**
 * Set up and bind a socket to a local ip and an ephemeral port
 * @param blocking non-zero for a blocking socket, 0 to switch the socket
 * to non-blocking mode
 * @param localip the local IPv4 address to bind to
 * @return the bound socket if successful, else -1 (the socket is closed
 * on every failure path)
 */
int tcp_bind( int blocking, char *localip )
{
    struct sockaddr_in addr;
    int res;
    /* create an IPv4 TCP socket */
    int sock = socket( AF_INET, SOCK_STREAM, 0 );

    if ( sock == -1 )
    {
        printf("couldn't open nb-socket. error=%s\n",strerror(errno));
        return -1;
    }
    if ( !blocking )
    {
        /* get existing socket flags, then add O_NONBLOCK to them */
        int flags = fcntl( sock, F_GETFL, 0 );
        if ( flags == -1 || fcntl( sock, F_SETFL, flags | O_NONBLOCK ) == -1 )
        {
            printf("failed to set socket to non-blocking mode\n");
            close( sock );
            return -1;
        }
    }
    /* set the receive timeout */
    if ( !socket_set_timeout(sock,TIMEOUT_SECS) )
    {
        close( sock );
        return -1;
    }

    memset( &addr, 0, sizeof(addr) );
    addr.sin_family = AF_INET;
    /* port 0 lets the system choose the socket's source port */
    addr.sin_port = 0;
    /* load localip as the socket's source address */
    res = inet_pton( AF_INET, localip, &addr.sin_addr );
    if ( res == 0 )
    {
        /* errno is not set when the text is simply not an address */
        printf("inet_pton error: %s is not a valid IPv4 address\n",localip);
    }
    else if ( res != 1 )
    {
        printf("inet_pton error %s\n",strerror(errno));
    }
    else if ( bind( sock, (const struct sockaddr *)&addr, sizeof(addr) ) != -1 )
    {
        printf("bound socket %d to %s \n",sock,localip);
        return sock;
    }
    else
    {
        printf("failed to bind sock %d to %s. err=%s\n",
            sock,localip,strerror(errno));
    }
    close( sock );
    printf( "closed b-socket %d on error\n",sock);
    return -1;
}
/**
 * Blocking exchange: bind a socket on 127.0.0.1, connect it to the
 * destination, send the prepared message and read the reply.
 * @param argv program name, dest-ip and dest-port
 * @return 1 if the message was sent and at least one byte came back, else 0
 */
int sendn( char **argv )
{
    int ok = 0;
    int fd = tcp_bind( 1, "127.0.0.1" );

    if ( fd == -1 )
        return ok;

    if ( do_connect(fd,argv[1],argv[2]) )
    {
        /* push the whole message out before reading anything */
        int nsent = writen( fd, message, message_len );
        if ( nsent != message_len )
        {
            printf("tried to send %d via blocking socket but only sent %d\n",
                message_len,nsent);
        }
        else
        {
            printf( "successfully sent %d bytes to %s, port %s\n",
                nsent,argv[1],argv[2]);
            int nrecv = readn( fd );
            ok = ( nrecv > 0 );
            if ( ok )
                printf("received %d bytes from %s\n",nrecv,argv[1]);
            else
                printf("failed to read anything or error\n");
        }
    }
    close( fd );
    return ok;
}
/**
 * Test for writing and if possible, write. Progress is kept in the global
 * message_pos so that a later call resumes where this one stopped.
 * @param sock the socket to write to
 * @return 1 once the whole message has been written, else 0 (the global
 * state is set to error on failure; it is left alone when the socket is
 * simply not ready yet)
 */
int try_writen( int sock )
{
    struct pollfd fds[1];
    int res;

    /* the descriptor and event mask must be fully initialised for poll() */
    fds[0].fd = sock;
    fds[0].events = POLLOUT | POLLWRBAND;
    fds[0].revents = 0;

    res = poll( fds, 1, POLL_TIMEOUT_MSECS );
    if ( res < 0 )
    {
        /* an interrupted poll just means "ask again" */
        if ( errno != EINTR )
        {
            printf( "error: %s\n", strerror(errno) );
            state = error;
        }
        return 0;
    }
    if ( res == 0 )
        return 0;

    while ( message_pos < message_len )
    {
        ssize_t nwritten = write( sock, &message[message_pos],
            (size_t)(message_len-message_pos) );
        if ( nwritten > 0 )
        {
            message_pos += (int)nwritten;
            continue;
        }
        /* "try again later" conditions are not errors; anything else is */
        if ( nwritten == 0
            || ( errno != EINTR && errno != EAGAIN && errno != EWOULDBLOCK ) )
            state = error;
        return 0;
    }
    /* everything sent: rewind for the next message */
    message_pos = 0;
    return 1;
}
/**
 * Test for reading on socket and if you can, read. Data lands in the
 * shared line buffer (each read overwrites the last); the running byte
 * count is remembered between calls until the peer closes the connection.
 * @param sock the socket to read from
 * @return the total number of bytes received once the peer has closed
 * the connection, else 0 (the global state is set to error on failure or
 * when the peer closed without sending anything; it is left alone when
 * more data may still arrive)
 */
int try_readn( int sock )
{
    /* bytes received so far in the current response, across calls */
    static int received = 0;
    struct pollfd fds[1];
    int total;
    int res;

    /* the descriptor and event mask must be fully initialised for poll() */
    fds[0].fd = sock;
    fds[0].events = POLLIN | POLLPRI;
    fds[0].revents = 0;

    res = poll( fds, 1, POLL_TIMEOUT_MSECS );
    if ( res < 0 )
    {
        /* an interrupted poll just means "ask again" */
        if ( errno != EINTR )
        {
            printf("error: %s\n",strerror(errno));
            state = error;
            received = 0;
        }
        return 0;
    }
    if ( res == 0 )
        return 0;

    for ( ; ; )
    {
        ssize_t n = read( sock, line, MAXLINE );
        if ( n > 0 )
        {
            received += (int)n;
            continue;
        }
        /* end of stream: just finished reading */
        if ( n == 0 )
            break;
        if ( errno == EINTR )
            continue;
        /* no more data for now: keep the count and come back later */
        if ( errno != EAGAIN && errno != EWOULDBLOCK )
        {
            state = error;
            printf("error: %s\n",strerror(errno));
            received = 0;
        }
        return 0;
    }

    total = received;
    received = 0;
    if ( total == 0 )
    {
        /* without this the caller would poll the closed socket forever */
        printf("connection closed before any data was received\n");
        state = error;
    }
    return total;
}
/**
 * Send non-blocking
 * @param argv the argument list
 */
int sendnb( char **argv )
{
    int res;
    int sock = tcp_bind( 0, "127.0.0.1" );
    if ( sock != -1 )
    {
        do
        {
            switch ( state )
            {
                case initial:
                    res = do_connect( sock, argv[1], argv[2] );
                    if ( res )
                        state = writing;
                    break;
                case connecting:
                    res = try_connect( sock );
                    if ( res )
                        state = writing;
                    break;
                case writing:
                    res = try_writen( sock );
                    if ( res )
                        state = reading;
                    break;
                case reading:
                    res = try_readn( sock );
                    if ( res )
                        state = done;
                    break;
            }
        } 
        while ( state != done && state != error );
        close( sock );
        if ( state == done )
            return 1;
    }
    return 0;
}
/*
 * test program to open a blocking socket, then close it
 * and reopen it on the same ip as non-blocking. Each time
 * send and receive data.
 * Usage: ./socktest dest_ip port
 * Exits with EXIT_SUCCESS only when both exchanges succeed.
*/
int main( int argc, char **argv )
{
    if ( argc != 3 )
    {
        printf("usage: ./socktest dest_ip port\n");
        return EXIT_FAILURE;
    }
    state = initial;
    message_len = snprintf( message, MESSAGE_LEN, FORMAT, argv[1] );
    /* snprintf returns the untruncated length: a value of MESSAGE_LEN or
       more means the request did not fit and must not be sent */
    if ( message_len <= 0 || message_len >= MESSAGE_LEN )
    {
        printf("dest_ip is too long for the %d byte message buffer\n",
            MESSAGE_LEN);
        return EXIT_FAILURE;
    }
    /* send in blocking and then non-blocking mode */
    if ( !sendn(argv) )
        return EXIT_FAILURE;
    printf("blocking send succeeded\n");
    if ( !sendnb(argv) )
        return EXIT_FAILURE;
    printf("sent non-blocking successfully\n");
    return EXIT_SUCCESS;
}

/*
 * No comments:
 *
 * Post a Comment
 */